示例#1
0
class SatelliteMisfitResult(gf.Result, MisfitResult):
    """Result pairing a target's observed statics with its synthetics.

    Both attributes are optional dictionaries; this declaration places no
    constraint on their key or value types.
    """
    statics_syn = Dict.T(
        help='Predicted static displacements for a target (synthetics).',
        optional=True)
    statics_obs = Dict.T(
        help='Observed static displacement for a target.',
        optional=True)
示例#2
0
class CatalogConfig(Object):
    """Catalog-related configuration fields.

    Groups the switches for searching events and using local catalogs or
    subsets, the magnitude/time/distance limits, the backazimuthal wedge
    selection settings and the toggles for the individual catalog plots.
    Fields declared without ``default`` or ``optional`` carry no default
    value in this declaration.
    """
    search_events = Bool.T()
    use_local_catalog = Bool.T()
    subset_of_local_catalog = Bool.T()
    use_local_subsets = Bool.T()
    subset_fns = Dict.T()
    # catalog
    catalog_fn = String.T(
        optional=True, help='Name of local catalog to save to or to read from')
    dist = Float.T(default=165.0,
                   help='Max. distance to show on catalog plot (map) [deg]')
    min_mag = Float.T()
    max_mag = Float.T()
    tmin_str = String.T()
    tmax_str = String.T()
    wedges_width = Int.T(
        help='Find one event in backazimuthal wedge of this width')
    mid_point = List.T(
        optional=True,
        # Adjacent literals are used instead of a backslash continuation
        # inside the string: the latter would embed the source indentation
        # of the following line in the help text.
        help='Enter some estimation on centre point of array, needed for '
             'subset based on distances and plotting only.')
    median_ev_in_bin = Bool.T(default=True)
    # weighted_magn_baz_ev = Bool.T()

    min_dist_km = Float.T()
    max_dist_km = Float.T()
    depth_options = Dict.T()

    # Plot toggles, all disabled by default.
    plot_catalog_all = Bool.T(default=False)
    plot_hist_wedges = Bool.T(default=False)
    plot_wedges_vs_dist = Bool.T(default=False)
    plot_wedges_vs_magn = Bool.T(default=False)
    plot_dist_vs_magn = Bool.T(default=False)
    plot_catalog_subset = Bool.T(default=False)
示例#3
0
class SatelliteMisfitResult(gf.Result, MisfitResult):
    """Carries the observations for a target and corresponding synthetics.

    Both attributes are optional dictionaries mapping a string key to a
    one-dimensional float array that is serialized as base64.
    """
    # The builtin ``float`` replaces the ``num.float`` alias, which NumPy
    # deprecated in 1.20 and removed in 1.24. The alias was the builtin
    # float itself, so the declared dtype is unchanged.
    statics_syn = Dict.T(
        String.T(),
        Array.T(dtype=float, shape=(None,), serialize_as='base64'),
        optional=True,
        help='Predicted static displacements for a target (synthetics).')
    statics_obs = Dict.T(
        String.T(),
        Array.T(dtype=float, shape=(None,), serialize_as='base64'),
        optional=True,
        help='Observed static displacement for a target.')
示例#4
0
class PlotOptions(Object):
    """Options controlling which model is plotted and how plots are written."""

    post_llk = String.T(
        help='Which model to plot on the specified plot; options:'
             ' "max", "min", "mean"',
        default='max')
    plot_projection = String.T(
        help='Projection to use for plotting geodetic data; options: "latlon"',
        default='utm')
    utm_zone = Int.T(
        help='Only relevant if plot_projection is "utm"',
        optional=True,
        default=36)
    load_stage = String.T(
        help='Which ATMCMC stage to select for plotting',
        default='final')
    figure_dir = String.T(
        help='Name of the output directory of plots',
        default='figures')
    reference = Dict.T(
        optional=True,
        help='Reference point for example from a synthetic test.')
    # Output settings of the figures.
    outformat = String.T(default='pdf')
    dpi = Int.T(default=300)
    force = Bool.T(default=False)
示例#5
0
class VolumePointProblemConfig(ProblemConfig):
    """Configuration from which a ``VolumePointProblem`` is built."""
    ranges = Dict.T(String.T(), gf.Range.T())
    distance_min = Float.T(default=0.0)
    nthreads = Int.T(default=1)

    def get_problem(self, event, target_groups, targets):
        """Create the ``VolumePointProblem`` for *event*.

        A missing event depth is replaced by ``0.`` on the event itself.
        The base source is an explosion source derived from the event, with
        a half-sinusoid source time function of the event's duration
        (``0.0`` when the duration is unset or zero).

        :param event: event providing name, time, depth and duration
        :param target_groups: target groups handed to the problem
        :param targets: targets handed to the problem
        :returns: the configured ``VolumePointProblem``
        """
        if event.depth is None:
            event.depth = 0.

        source = gf.ExplosionSource.from_pyrocko_event(event)
        source.stf = gf.HalfSinusoidSTF(duration=event.duration or 0.0)

        # Values available for substitution in the name template.
        template_values = dict(
            event_name=event.name,
            event_time=util.time_to_str(event.time))
        problem_name = expand_template(self.name_template, template_values)

        return VolumePointProblem(
            name=problem_name,
            base_source=source,
            target_groups=target_groups,
            targets=targets,
            ranges=self.ranges,
            distance_min=self.distance_min,
            norm_exponent=self.norm_exponent,
            nthreads=self.nthreads)
示例#6
0
class RectangularProblemConfig(ProblemConfig):
    """Configuration from which a ``RectangularProblem`` is built."""

    ranges = Dict.T(String.T(), gf.Range.T())
    decimation_factor = Int.T(default=1)
    distance_min = Float.T(default=0.)
    nthreads = Int.T(default=4)

    def get_problem(self, event, target_groups, targets):
        """Create the ``RectangularProblem`` for *event*.

        Logs a warning when ``decimation_factor`` differs from 1. The base
        source is a rectangular source derived from the event, anchored at
        the top.

        :param event: event providing the name and time for the problem name
        :param target_groups: target groups handed to the problem
        :param targets: targets handed to the problem
        :returns: the configured ``RectangularProblem``
        """
        if self.decimation_factor != 1:
            # ``Logger.warn`` is a deprecated alias of ``Logger.warning``;
            # the arguments are passed lazily so that formatting only
            # happens when the record is actually emitted.
            logger.warning(
                'Decimation factor for rectangular source set to %i. Results '
                'may be inaccurate.', self.decimation_factor)

        base_source = gf.RectangularSource.from_pyrocko_event(
            event, anchor='top', decimation_factor=self.decimation_factor)

        # Values available for substitution in the name template.
        subs = dict(event_name=event.name,
                    event_time=util.time_to_str(event.time))

        problem = RectangularProblem(
            name=expand_template(self.name_template, subs),
            base_source=base_source,
            distance_min=self.distance_min,
            target_groups=target_groups,
            targets=targets,
            ranges=self.ranges,
            norm_exponent=self.norm_exponent,
            nthreads=self.nthreads)

        return problem
示例#7
0
class CMTProblemConfig(ProblemConfig):
    """Configuration from which a ``CMTProblem`` is built."""

    ranges = Dict.T(String.T(), gf.Range.T())
    distance_min = Float.T(default=0.0)
    mt_type = StringChoice.T(choices=['full', 'deviatoric'])

    def get_problem(self, event, target_groups, targets):
        """Create the ``CMTProblem`` for *event*.

        A missing event depth is replaced by ``0.`` on the event itself.
        The base source is a moment tensor source derived from the event,
        with a half-sinusoid source time function of the event's duration
        (``0.0`` when the duration is unset or zero).

        :param event: event providing name, time, depth and duration
        :param target_groups: target groups handed to the problem
        :param targets: targets handed to the problem
        :returns: the configured ``CMTProblem``
        """
        if event.depth is None:
            event.depth = 0.

        source = gf.MTSource.from_pyrocko_event(event)
        source.stf = gf.HalfSinusoidSTF(duration=event.duration or 0.0)

        # Values available for substitution in the name template.
        template_values = dict(
            event_name=event.name,
            event_time=util.time_to_str(event.time))
        problem_name = expand_template(self.name_template, template_values)

        return CMTProblem(
            name=problem_name,
            base_source=source,
            target_groups=target_groups,
            targets=targets,
            ranges=self.ranges,
            distance_min=self.distance_min,
            mt_type=self.mt_type,
            norm_exponent=self.norm_exponent)
示例#8
0
class PlotGroup(Object):
    """A named group of plot items with formats, variant and attributes."""
    name = StringID.T(help='group name')
    section = StringID.T(
        help='group\'s section path, e.g. results.waveforms',
        optional=True)
    title = Unicode.T(
        help='group\'s title',
        optional=True)
    description = Unicode.T(
        help='group description',
        optional=True)
    formats = List.T(PlotFormat.T(), help='plot format')
    variant = StringID.T(help='variant of the group')
    feather_icon = String.T(
        help='Feather icon for the HTML report.',
        default='bar-chart-2')
    size_cm = Tuple.T(2, Float.T())
    items = List.T(PlotItem.T())
    attributes = Dict.T(StringID.T(), List.T(String.T()))

    def filename_image(self, item, format):
        """Return the image file name for *item* rendered in *format*.

        The name is ``<group name>.<variant>.<item name>.<extension>``.

        :param item: object whose ``name`` goes into the file name
        :param format: object whose ``extension`` is used as the suffix
        """
        parts = (self.name, self.variant, item.name, format.extension)
        return '.'.join(str(part) for part in parts)
示例#9
0
class CMTProblemConfig(ProblemConfig):
    """Configuration from which a ``CMTProblem`` is built."""

    ranges = Dict.T(String.T(), gf.Range.T())
    distance_min = Float.T(default=0.0)
    mt_type = MTType.T(default='full')
    stf_type = STFType.T(default='HalfSinusoidSTF')
    nthreads = Int.T(default=1)

    def get_problem(self, event, target_groups, targets):
        """Create the ``CMTProblem`` for *event*.

        A missing event depth is replaced by ``0.`` on the event itself.
        The base source is a moment tensor source derived from the event;
        its source time function is obtained from ``STFType.base_stf`` for
        the configured ``stf_type`` and given the event's duration (``0.0``
        when the duration is unset or zero).

        :param event: event providing name, time, depth and duration
        :param target_groups: target groups handed to the problem
        :param targets: targets handed to the problem
        :returns: the configured ``CMTProblem``
        """
        if event.depth is None:
            event.depth = 0.

        source = gf.MTSource.from_pyrocko_event(event)

        source_time_function = STFType.base_stf(self.stf_type)
        source_time_function.duration = event.duration or 0.0
        source.stf = source_time_function

        # Values available for substitution in the name template.
        template_values = dict(
            event_name=event.name,
            event_time=util.time_to_str(event.time))
        problem_name = expand_template(self.name_template, template_values)

        return CMTProblem(
            name=problem_name,
            base_source=source,
            target_groups=target_groups,
            targets=targets,
            ranges=self.ranges,
            distance_min=self.distance_min,
            mt_type=self.mt_type,
            stf_type=self.stf_type,
            norm_exponent=self.norm_exponent,
            nthreads=self.nthreads)
示例#10
0
class PlotItem(Object):
    """A single named plot item with attributes, title and description."""
    name = StringID.T()
    attributes = Dict.T(
        StringID.T(), List.T(String.T()))
    # The help text previously duplicated the one of ``description``.
    title = Unicode.T(
        optional=True,
        help='item\'s title')
    description = Unicode.T(
        optional=True,
        help='item\'s description')
示例#11
0
class SnufflerConfig(ConfigBase):
    """Configuration holding visible-length presets and phase key mapping."""
    visible_length_setting = List.T(
        VisibleLengthSetting.T(),
        default=[
            VisibleLengthSetting(key='Short', value=6000.),
            VisibleLengthSetting(key='Medium', value=20000.),
            VisibleLengthSetting(key='Long', value=60000.),
        ])
    phase_key_mapping = Dict.T(
        String.T(),
        String.T(),
        default=default_phase_key_mapping)

    def get_phase_name(self, key):
        """Return the phase name mapped to function key *key*.

        The lookup key is *key* prefixed with ``'F'``; ``'Undefined'`` is
        returned when the mapping has no such entry.
        """
        lookup_key = 'F%s' % key
        return self.phase_key_mapping.get(lookup_key, 'Undefined')
示例#12
0
class TestResult(Object):
    """Record of one test run: package, box, versions, log and outcomes."""
    package = String.T()
    branch = String.T(optional=True)
    box = String.T()
    py_version = String.T(optional=True)
    prerequisite_versions = Dict.T(
        String.T(), String.T(), default={}, optional=True)
    log = String.T(yamlstyle='|', optional=True)
    result = String.T(optional=True)
    # The three outcome lists share the same declaration.
    errors = List.T(String.T(), yamlstyle='block', default=[], optional=True)
    fails = List.T(String.T(), yamlstyle='block', default=[], optional=True)
    skips = List.T(String.T(), yamlstyle='block', default=[], optional=True)
示例#13
0
class SatelliteMisfitConfig(MisfitConfig):
    """Misfit configuration with an orbital ramp switch and its ranges."""
    optimise_orbital_ramp = Bool.T(
        help='Switch to account for a linear orbital ramp or not',
        default=True)
    ranges = Dict.T(
        String.T(),
        gf.Range.T(),
        help='These parameters give bounds for an offset [m], a linear'
             ' gradient in east direction [m/m] and a linear gradient in north'
             ' direction [m/m]. Note, while the optimisation of these ramps'
             ' is individual for each target, the ranges set here are common'
             ' for all satellite targets.',
        default={
            'offset': '-0.5 .. 0.5',
            'ramp_east': '-1e-7 .. 1e-7',
            'ramp_north': '-1e-7 .. 1e-7',
        })
示例#14
0
class MetaDataDownloadConfig(Object):
    """Configuration fields for data and metadata download.

    Fields declared without ``default`` carry no default value in this
    declaration.
    """
    # Data and Metadata download
    download_data = Bool.T()
    download_metadata = Bool.T()
    # Local sources; both lists default to empty.
    local_metadata = List.T(default=[])
    local_data = List.T(default=[])
    local_waveforms_only = Bool.T(default=False)
    sds_structure = Bool.T(default=False)
    use_downmeta = Bool.T(default=True)
    channels_download = String.T()
    all_channels = Bool.T(default=False)
    # components = List.T()
    # NOTE(review): presumably access tokens for the download sites;
    # the expected keys are not visible here.
    token = Dict.T(default={})
    sites = List.T()
    # start and end time with respect to origin time
    # start time before origin time!
    dt_start = Float.T()
    dt_end = Float.T()
示例#15
0
class MapParameters(Object):
    """Parameters describing a map: extent, centre, content and colours."""
    width = Float.T(help="width of map", optional=True, default=16)
    # Help text corrected from the misspelt "heigth".
    height = Float.T(help="height of map", optional=True, default=16)
    lat = Float.T(help="center latitude")
    lon = Float.T(help="center longitude")
    radius = Float.T(help="radius of map")
    outfn = String.T(help="output filename", optional=True, default="map.png")
    stations = List.T(model.Station.T(), optional=True)
    station_label_mapping = List.T(String.T(), optional=True)
    station_colors = Dict.T(String.T(), String.T(), optional=True)
    events = List.T(model.Event.T(), optional=True)
    show_topo = Bool.T(optional=True, default=True)
    show_grid = Bool.T(optional=True, default=True)
    # Three-component integer colours.
    color_wet = Tuple.T(3, Int.T(), default=(200, 200, 200))
    color_dry = Tuple.T(3, Int.T(), default=(253, 253, 253))

    @classmethod
    def process_args(cls, *args, **kwargs):
        """Build an instance from keyword arguments.

        Positional arguments are accepted but ignored; only ``kwargs`` are
        forwarded to the constructor.
        """
        return cls(**kwargs)
示例#16
0
class Meta(Object):
    """Meta configuration for ``Scene``."""

    scene_title = String.T(default="Unnamed Scene", help="Scene title")
    scene_id = String.T(default="None", help="Scene identification")
    satellite_name = String.T(
        default="Undefined Mission", help="Satellite mission name"
    )
    wavelength = Float.T(optional=True, help="Wavelength in [m]")
    orbital_node = StringChoice.T(
        choices=["Ascending", "Descending", "Undefined"],
        default="Undefined",
        help="Orbital direction, ascending/descending",
    )
    time_master = Timestamp.T(
        default=1481116161.930574, help="Timestamp for master acquisition"
    )
    time_slave = Timestamp.T(
        default=1482239325.482, help="Timestamp for slave acquisition"
    )
    extra = Dict.T(default={}, help="Extra header information")
    filename = String.T(optional=True)

    def __init__(self, *args, **kwargs):
        """Initialise the object, accepting legacy keyword names.

        The legacy keyword ``orbit_direction`` is renamed to
        ``orbital_node``; ``old_import`` is set to ``True`` when such a
        rename took place.
        """
        self.old_import = False

        # Legacy keyword name -> current attribute name.
        mapping = {"orbit_direction": "orbital_node"}

        for old, new in mapping.items():
            if old in kwargs:
                kwargs[new] = kwargs.pop(old)
                self.old_import = True

        Object.__init__(self, *args, **kwargs)

    @property
    def time_separation(self):
        """
        :getter: Time difference ``time_slave - time_master``; negative
                 when ``time_slave`` is the earlier timestamp
        :type: timedelta
        """
        # Local import keeps this block independent of the file's imports.
        from datetime import timedelta

        # The difference is taken on the timestamps themselves. Converting
        # each one to a naive local datetime first would shift the result
        # by the DST offset whenever a daylight-saving change lies between
        # the two acquisitions.
        # NOTE(review): assumes both timestamps are convertible with
        # ``float()`` - confirm for high-precision timestamp types.
        return timedelta(
            seconds=float(self.time_slave) - float(self.time_master))
示例#17
0
class RectangularProblemConfig(ProblemConfig):
    """Configuration from which a ``RectangularProblem`` is built."""

    ranges = Dict.T(String.T(), gf.Range.T())
    decimation_factor = Int.T(default=1)
    distance_min = Float.T(default=0.)

    def get_problem(self, event, target_groups, targets):
        """Create the ``RectangularProblem`` for *event*.

        The base source is a rectangular source derived from the event,
        anchored at the top and using the configured decimation factor.

        :param event: event providing the name and time for the problem name
        :param target_groups: target groups handed to the problem
        :param targets: targets handed to the problem
        :returns: the configured ``RectangularProblem``
        """
        source = gf.RectangularSource.from_pyrocko_event(
            event,
            anchor='top',
            decimation_factor=self.decimation_factor)

        # Values available for substitution in the name template.
        template_values = dict(
            event_name=event.name,
            event_time=util.time_to_str(event.time))
        problem_name = expand_template(self.name_template, template_values)

        return RectangularProblem(
            name=problem_name,
            base_source=source,
            distance_min=self.distance_min,
            target_groups=target_groups,
            targets=targets,
            ranges=self.ranges,
            norm_exponent=self.norm_exponent)
示例#18
0
文件: target.py 项目: mfkiwl/grond
class StationDictStoreIDSelector(StoreIDSelector):
    '''
    Select the store ID from a manually given station to store ID mapping.
    '''

    mapping = Dict.T(
        String.T(),
        gf.StringID.T(),
        help='Dictionary with station to store ID pairs, keys are NET.STA. '
        "Add a fallback store ID under the key ``'others'``.")

    def get_store_id(self, event, st, cha):
        '''
        Return the store ID configured for station *st*.

        The mapping is queried with ``NET.STA`` first and with the fallback
        key ``'others'`` second. *event* and *cha* are not used here.

        :raises StoreIDSelectorError: if neither key is in the mapping
        '''
        station_key = '%s.%s' % (st.network, st.station)
        try:
            return self.mapping[station_key]
        except KeyError:
            try:
                return self.mapping['others']
            except KeyError:
                raise StoreIDSelectorError(
                    'No store ID found for station "%s".' % station_key)
示例#19
0
class GainfactorsConfig(Object):
    """Configuration fields for the gain factor computation and its plots.

    Fields declared without ``default`` carry no default value in this
    declaration.
    """
    # Gainfactors
    calc_gainfactors = Bool.T()
    # The accepted options are described in the string literal that follows
    # this field.
    gain_factor_method = List.T()
    ''' gain relative to options:
       * 'scale_one' for reference amplitude = 1.
       *  tuple ('reference_nsl', refernce_id) - gain
          relative to one specific station,
          reference_id e.g. 'BFO'
       * 'median_all_avail': first assesses all abs. amplitudes,
          in the end chooses a reference station of those that recorded
          most events based on median gain of first event
    '''
    # NOTE(review): presumably the frequency band of a filter; the expected
    # keys are not visible here.
    fband = Dict.T()
    taper_xfrac = Float.T()
    # NOTE(review): looks like window start/stop relative to an arrival;
    # units and sign convention are not visible here.
    wdw_st_arr = Int.T()
    wdw_sp_arr = Int.T()
    snr_thresh = Float.T(default=1.)
    debug_mode = Bool.T(default=False)
    phase_select = String.T()
    components = List.T()

    # Plot toggles.
    plot_median_gain_on_map = Bool.T()
    plot_allgains = Bool.T()
示例#20
0
class Problem(Object):
    '''
    Base class for objective function setup.

    Defines the *problem* to be solved by the optimiser.

    The methods of this class rely on attributes and methods which are not
    defined here (``problem_parameters``, ``get_source``, ``pack``,
    ``make_dependant``), so they have to be provided elsewhere, e.g. by
    subclasses.
    '''
    name = String.T()
    ranges = Dict.T(String.T(), gf.Range.T())
    dependants = List.T(Parameter.T())
    norm_exponent = Int.T(default=2)
    base_source = gf.Source.T(optional=True)
    targets = List.T(MisfitTarget.T())
    target_groups = List.T(TargetGroup.T())
    grond_version = String.T(optional=True)
    nthreads = Int.T(default=1)

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)

        if self.grond_version is None:
            self.grond_version = __version__

        # Lazily filled caches, see get_target_weights / get_family_mask.
        self._target_weights = None
        self._engine = None
        self._family_mask = None

        # Waveform specific parameters only take part when there is at
        # least one waveform target.
        if hasattr(self, 'problem_waveform_parameters') and self.has_waveforms:
            self.problem_parameters =\
                self.problem_parameters + self.problem_waveform_parameters

        self.check()

    @classmethod
    def get_plot_classes(cls):
        '''Return the plot classes provided by the sibling ``plot`` module.'''
        from . import plot
        return plot.get_plot_classes()

    def check(self):
        '''
        Check that no target group path is defined more than once.

        Groups with the path ``'all'`` are exempt from the check.

        :raises ValueError: if a path occurs in more than one group
        '''
        paths = set()
        for grp in self.target_groups:
            if grp.path == 'all':
                continue
            if grp.path in paths:
                raise ValueError('Path %s defined more than once! In %s' %
                                 (grp.path, grp.__class__.__name__))
            paths.add(grp.path)
        logger.debug('TargetGroup check OK.')

    def copy(self):
        '''Return a shallow copy with the target weight cache reset.'''
        o = copy.copy(self)
        o._target_weights = None
        return o

    def set_target_parameter_values(self, x):
        '''
        Hand each target its slice of the model vector *x*.

        The target parameters follow the problem parameters in *x*, in the
        order of ``self.targets``.
        '''
        nprob = len(self.problem_parameters)
        for target in self.targets:
            target.set_parameter_values(x[nprob:nprob + target.nparameters])
            nprob += target.nparameters

    def get_parameter_dict(self, model, group=None):
        '''
        Return an ``ADict`` of parameter name to value taken from *model*.

        :param group: if given, only parameters in this group are included
        '''
        params = []
        for ip, p in enumerate(self.parameters):
            if group in p.groups or group is None:
                params.append((p.name, model[ip]))
        return ADict(params)

    def get_parameter_array(self, d):
        '''
        Return a float array of all parameters filled from mapping *d*.

        Parameters without an entry in *d* are left at zero.
        '''
        # Builtin ``float`` replaces the removed ``num.float`` alias.
        arr = num.zeros(self.nparameters, dtype=float)
        for ip, p in enumerate(self.parameters):
            if p.name in d.keys():
                arr[ip] = d[p.name]
        return arr

    def dump_problem_info(self, dirname):
        '''Dump this object to ``problem.yaml`` inside *dirname*.'''
        fn = op.join(dirname, 'problem.yaml')
        util.ensuredirs(fn)
        guts.dump(self, filename=fn)

    def dump_problem_data(self,
                          dirname,
                          x,
                          misfits,
                          bootstraps=None,
                          sampler_context=None):
        '''
        Append model data to the binary files inside *dirname*.

        *x*, *misfits* and *bootstraps* are appended as little-endian
        float64 to the files ``models``, ``misfits`` and ``bootstraps``,
        *sampler_context* as little-endian int64 to ``choices``. The two
        optional arguments are skipped when ``None``.
        '''
        fn = op.join(dirname, 'models')
        if not isinstance(x, num.ndarray):
            x = num.array(x)
        with open(fn, 'ab') as f:
            x.astype('<f8').tofile(f)

        fn = op.join(dirname, 'misfits')
        with open(fn, 'ab') as f:
            misfits.astype('<f8').tofile(f)

        if bootstraps is not None:
            fn = op.join(dirname, 'bootstraps')
            with open(fn, 'ab') as f:
                bootstraps.astype('<f8').tofile(f)

        if sampler_context is not None:
            fn = op.join(dirname, 'choices')
            with open(fn, 'ab') as f:
                num.array(sampler_context, dtype='<i8').tofile(f)

    def name_to_index(self, name):
        '''
        Return the index of the combined parameter called *name*.

        :raises ValueError: if there is no such parameter
        '''
        pnames = [p.name for p in self.combined]
        return pnames.index(name)

    @property
    def parameters(self):
        '''Problem parameters followed by the parameters of all targets.'''
        target_parameters = []
        for target in self.targets:
            target_parameters.extend(target.target_parameters)
        return self.problem_parameters + target_parameters

    @property
    def parameter_names(self):
        '''Names of the combined parameters (parameters and dependants).'''
        return [p.name for p in self.combined]

    @property
    def dependant_names(self):
        '''Names of the dependants.'''
        return [p.name for p in self.dependants]

    @property
    def nparameters(self):
        '''Number of parameters.'''
        return len(self.parameters)

    @property
    def ntargets(self):
        '''Number of targets.'''
        return len(self.targets)

    @property
    def nwaveform_targets(self):
        '''Number of waveform targets.'''
        return len(self.waveform_targets)

    @property
    def nsatellite_targets(self):
        '''Number of satellite targets.'''
        return len(self.satellite_targets)

    @property
    def ngnss_targets(self):
        '''Number of GNSS campaign targets.'''
        return len(self.gnss_targets)

    @property
    def nmisfits(self):
        '''Sum of ``nmisfits`` over all targets.'''
        nmisfits = 0
        for target in self.targets:
            nmisfits += target.nmisfits
        return nmisfits

    @property
    def ndependants(self):
        '''Number of dependants.'''
        return len(self.dependants)

    @property
    def ncombined(self):
        '''Number of parameters plus number of dependants.'''
        return len(self.parameters) + len(self.dependants)

    @property
    def combined(self):
        '''Parameters followed by dependants.'''
        return self.parameters + self.dependants

    @property
    def satellite_targets(self):
        '''Targets which are instances of ``SatelliteMisfitTarget``.'''
        return [
            t for t in self.targets if isinstance(t, SatelliteMisfitTarget)
        ]

    @property
    def gnss_targets(self):
        '''Targets which are instances of ``GNSSCampaignMisfitTarget``.'''
        return [
            t for t in self.targets if isinstance(t, GNSSCampaignMisfitTarget)
        ]

    @property
    def waveform_targets(self):
        '''Targets which are instances of ``WaveformMisfitTarget``.'''
        return [t for t in self.targets if isinstance(t, WaveformMisfitTarget)]

    @property
    def has_satellite(self):
        '''``True`` if there is at least one satellite target.'''
        return bool(self.satellite_targets)

    @property
    def has_waveforms(self):
        '''``True`` if there is at least one waveform target.'''
        return bool(self.waveform_targets)

    def set_engine(self, engine):
        '''Set the engine used for modelling.'''
        self._engine = engine

    def get_engine(self):
        '''Return the engine set with :py:meth:`set_engine`, or ``None``.'''
        return self._engine

    def get_gf_store(self, target):
        '''
        Return the store with the ID ``target.store_id`` from the engine.

        :raises GrondError: if no engine has been set
        '''
        if self.get_engine() is None:
            raise GrondError('Cannot get GF Store, modelling is not set up!')
        return self.get_engine().get_store(target.store_id)

    def random_uniform(self, xbounds, rstate):
        '''
        Draw a uniformly distributed random model within *xbounds*.

        :param xbounds: 2D array with lower bounds in column 0 and upper
            bounds in column 1
        :param rstate: random state providing ``uniform``
        '''
        x = rstate.uniform(0., 1., self.nparameters)
        x *= (xbounds[:, 1] - xbounds[:, 0])
        x += xbounds[:, 0]
        return x

    def preconstrain(self, x):
        '''Return *x* unchanged; hook which may be overridden.'''
        return x

    def extract(self, xs, i):
        '''
        Extract the values of combined parameter *i* from the models *xs*.

        For a 1D *xs* (single model) a single value is returned. Indices
        beyond the parameters are delegated to ``make_dependant``.
        '''
        if xs.ndim == 1:
            return self.extract(xs[num.newaxis, :], i)[0]

        if i < self.nparameters:
            return xs[:, i]
        else:
            return self.make_dependant(
                xs, self.dependants[i - self.nparameters].name)

    def get_target_weights(self):
        '''Return the concatenated combined weights of all targets (cached).'''
        if self._target_weights is None:
            self._target_weights = num.concatenate(
                [target.get_combined_weight() for target in self.targets])

        return self._target_weights

    def get_target_residuals(self):
        '''Not implemented here; returns ``None``.'''
        pass

    def inter_family_weights(self, ns):
        '''
        :param ns: 1D array with normalization factors ``ns[imisfit]``
        :returns: 1D array ``weights[imisfit]``, constant within each
            normalisation family
        '''
        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()

        ws = num.zeros(self.nmisfits)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[mask] = 1.0 / root(num.nansum(exp(ns[mask])))

        return ws

    def inter_family_weights2(self, ns):
        '''
        :param ns: 2D array with normalization factors ``ns[imodel, itarget]``
        :returns: 2D array ``weights[imodel, itarget]``
        '''

        exp, root = self.get_norm_functions()

        family, nfamilies = self.get_family_mask()
        ws = num.zeros(ns.shape)
        for ifamily in range(nfamilies):
            mask = family == ifamily
            ws[:, mask] = (
                1.0 / root(num.nansum(exp(ns[:, mask]), axis=1)))[:,
                                                                  num.newaxis]
        return ws

    def get_reference_model(self, expand=False):
        '''
        Return the packed parameters of the base source.

        :param expand: if ``True``, return an array of length
            ``nparameters`` with the source parameters at its beginning
            and zeros elsewhere
        '''
        if expand:
            src_params = self.pack(self.base_source)
            ref = num.zeros(self.nparameters)
            ref[:src_params.size] = src_params
        else:
            ref = self.pack(self.base_source)
        return ref

    def get_parameter_bounds(self):
        '''
        Return a float array of ``(start, stop)`` pairs for all parameters.

        Bounds of problem parameters come from ``self.ranges``, those of
        target parameters from the respective target's ``target_ranges``.
        '''
        out = []
        for p in self.problem_parameters:
            r = self.ranges[p.name]
            out.append((r.start, r.stop))

        for target in self.targets:
            for p in target.target_parameters:
                r = target.target_ranges[p.name_nogroups]
                out.append((r.start, r.stop))

        # Builtin ``float`` replaces the removed ``num.float`` alias.
        return num.array(out, dtype=float)

    def get_dependant_bounds(self):
        '''Return an empty ``(0, 2)`` array; hook which may be overridden.'''
        return num.zeros((0, 2))

    def get_combined_bounds(self):
        '''Return the parameter bounds stacked on the dependant bounds.'''
        return num.vstack(
            (self.get_parameter_bounds(), self.get_dependant_bounds()))

    def raise_invalid_norm_exponent(self):
        '''
        :raises GrondError: always, reporting the current norm exponent
        '''
        raise GrondError('Invalid norm exponent: %f' % self.norm_exponent)

    def get_norm_functions(self):
        '''
        Return the pair of functions ``(exp, root)`` for the norm exponent.

        For exponent 2 these are the square and ``num.sqrt``, for exponent
        1 the identity and ``num.abs``.

        :raises GrondError: for any other norm exponent
        '''
        if self.norm_exponent == 2:

            def sqr(x):
                return x**2

            return sqr, num.sqrt

        elif self.norm_exponent == 1:

            def noop(x):
                return x

            return noop, num.abs

        else:
            self.raise_invalid_norm_exponent()

    def combine_misfits(self,
                        misfits,
                        extra_weights=None,
                        extra_residuals=None,
                        get_contributions=False):
        '''
        Combine misfit contributions (residuals) to global or bootstrap misfits

        :param misfits: 3D array ``misfits[imodel, iresidual, 0]`` are the
            misfit contributions (residuals) ``misfits[imodel, iresidual, 1]``
            are the normalisation contributions. It is also possible to give
            the misfit and normalisation contributions for a single model as
            ``misfits[iresidual, 0]`` and ``misfits[iresidual, 1]`` in which
            case, the first dimension (imodel) of the result will be stripped
            off.

        :param extra_weights: if given, 2D array of extra weights to be applied
            to the contributions, indexed as
            ``extra_weights[ibootstrap, iresidual]``.

        :param extra_residuals: if given, 2D array of perturbations to be added
            to the residuals, indexed as
            ``extra_residuals[ibootstrap, iresidual]``.

        :param get_contributions: get the weighted and perturbed contributions
            (don't do the sum).

        :returns: if no *extra_weights* or *extra_residuals* are given, a 1D
            array indexed as ``misfits[imodel]`` containing the global misfit
            for each model is returned, otherwise a 2D array
            ``misfits[imodel, ibootstrap]`` with the misfit for every model and
            weighting/residual set is returned.
        '''

        exp, root = self.get_norm_functions()

        if misfits.ndim == 2:
            # Single model: add the model axis, recurse and strip it again.
            misfits = misfits[num.newaxis, :, :]
            return self.combine_misfits(misfits, extra_weights,
                                        extra_residuals,
                                        get_contributions)[0, ...]

        assert misfits.ndim == 3
        assert extra_weights is None or extra_weights.ndim == 2
        assert extra_residuals is None or extra_residuals.ndim == 2

        if extra_weights is not None or extra_residuals is not None:
            if extra_weights is not None:
                w = extra_weights[num.newaxis, :, :] \
                    * self.get_target_weights()[num.newaxis, num.newaxis, :] \
                    * self.inter_family_weights2(
                        misfits[:, :, 1])[:, num.newaxis, :]
            else:
                w = 1.0

            if extra_residuals is not None:
                r = extra_residuals[num.newaxis, :, :]
            else:
                r = 0.0

            if get_contributions:
                return exp(w*(misfits[:, num.newaxis, :, 0]+r)) \
                    / num.nansum(
                        exp(w*misfits[:, num.newaxis, :, 1]),
                        axis=2)[:, :, num.newaxis]

            res = root(
                num.nansum(exp(w * (misfits[:, num.newaxis, :, 0] + r)),
                           axis=2) /
                num.nansum(exp(w * (misfits[:, num.newaxis, :, 1])), axis=2))
            assert res[res < 0].size == 0
            return res
        else:
            # An unused second evaluation of this product (previously bound
            # to ``r``) has been dropped.
            w = self.get_target_weights()[num.newaxis, :] \
                * self.inter_family_weights2(misfits[:, :, 1])

            if get_contributions:
                return exp(w*misfits[:, :, 0]) \
                    / num.nansum(
                        exp(w*misfits[:, :, 1]),
                        axis=1)[:, num.newaxis]

            return root(
                num.nansum(exp(w * misfits[:, :, 0]), axis=1) /
                num.nansum(exp(w * misfits[:, :, 1]), axis=1))

    def make_family_mask(self):
        '''
        Assign each misfit the index of its normalisation family.

        Family indices are given in order of first appearance among the
        targets.

        :returns: tuple ``(families, nfamilies)`` where ``families`` is an
            integer array of length ``nmisfits``
        '''
        # Maps family name to its index. Looking the index up by name keeps
        # it correct when targets of one family are not contiguous;
        # deriving it from the number of families seen so far would assign
        # such targets to the most recently added family.
        family_indices = {}
        families = num.zeros(self.nmisfits, dtype=int)

        idx = 0
        for target in self.targets:
            ifamily = family_indices.setdefault(
                target.normalisation_family, len(family_indices))
            families[idx:idx + target.nmisfits] = ifamily
            idx += target.nmisfits

        return families, len(family_indices)

    def get_family_mask(self):
        '''Return the cached result of :py:meth:`make_family_mask`.'''
        if self._family_mask is None:
            self._family_mask = self.make_family_mask()

        return self._family_mask

    def evaluate(self, x, mask=None, result_mode='full', targets=None):
        '''
        Run the forward modelling for model *x* and return target results.

        :param x: model vector
        :param mask: if given, sequence of booleans, one per target;
            targets with a false entry are excluded from modelling
        :param result_mode: result mode set on every target
        :param targets: if given, targets to use instead of
            ``self.targets``; cannot be combined with *mask*
        :returns: list with one result per target; excluded targets get a
            ``gf.SeismosizerError`` instance as result
        :raises ValueError: if both *mask* and *targets* are given
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        self.set_target_parameter_values(x)

        if mask is not None and targets is not None:
            raise ValueError('Mask cannot be defined with targets set.')
        targets = targets if targets is not None else self.targets

        for target in targets:
            target.set_result_mode(result_mode)

        modelling_targets = []
        t2m_map = {}
        for itarget, target in enumerate(targets):
            t2m_map[target] = target.prepare_modelling(engine, source, targets)
            if mask is None or mask[itarget]:
                modelling_targets.extend(t2m_map[target])

        # Identical modelling targets are processed only once: map every
        # unique modelling target to the positions where it is needed.
        u2m_map = {}
        for imtarget, mtarget in enumerate(modelling_targets):
            if mtarget not in u2m_map:
                u2m_map[mtarget] = []

            u2m_map[mtarget].append(imtarget)

        modelling_targets_unique = list(u2m_map.keys())

        resp = engine.process(source,
                              modelling_targets_unique,
                              nthreads=self.nthreads)
        modelling_results_unique = list(resp.results_list[0])

        modelling_results = [None] * len(modelling_targets)

        for mtarget, mresult in zip(modelling_targets_unique,
                                    modelling_results_unique):

            for itarget in u2m_map[mtarget]:
                modelling_results[itarget] = mresult

        imt = 0
        results = []
        for itarget, target in enumerate(targets):
            nmt_this = len(t2m_map[target])
            if mask is None or mask[itarget]:
                result = target.finalize_modelling(
                    engine, source, t2m_map[target],
                    modelling_results[imt:imt + nmt_this])

                imt += nmt_this
            else:
                result = gf.SeismosizerError(
                    'target was excluded from modelling')

            results.append(result)

        return results

    def misfits(self, x, mask=None):
        '''
        Evaluate model *x* and return its misfit contributions.

        :returns: array of shape ``(nmisfits, 2)``; rows of targets whose
            result is not a ``MisfitResult`` are left as NaN
        '''
        results = self.evaluate(x, mask=mask, result_mode='sparse')
        misfits = num.full((self.nmisfits, 2), num.nan)

        imisfit = 0
        for target, result in zip(self.targets, results):
            if isinstance(result, MisfitResult):
                misfits[imisfit:imisfit + target.nmisfits, :] = result.misfits

            imisfit += target.nmisfits

        return misfits

    def forward(self, x):
        '''
        Model the plain targets of all targets for model *x*.

        Results which are ``gf.SeismosizerError`` instances are logged at
        debug level and left out of the returned list.
        '''
        source = self.get_source(x)
        engine = self.get_engine()

        plain_targets = []
        for target in self.targets:
            plain_targets.extend(target.get_plain_targets(engine, source))

        resp = engine.process(source, plain_targets)

        results = []
        for target, result in zip(plain_targets, resp.results_list[0]):
            if isinstance(result, gf.SeismosizerError):
                logger.debug('%s.%s.%s.%s: %s' % (target.codes +
                                                  (str(result), )))
            else:
                results.append(result)

        return results

    def get_random_model(self, ntries_limit=100):
        '''
        Draw a random model accepted by :py:meth:`preconstrain`.

        :param ntries_limit: maximum number of candidates to try
        :raises GrondError: if every candidate was rejected as ``Forbidden``
        '''
        xbounds = self.get_parameter_bounds()

        for _ in range(ntries_limit):
            x = self.random_uniform(xbounds, rstate=g_rstate)
            try:
                return self.preconstrain(x)

            except Forbidden:
                pass

        raise GrondError(
            'Could not find any suitable candidate sample within %i tries' %
            (ntries_limit))
示例#21
0
class GNSSCampaignMisfitResult(MisfitResult):
    """Misfit result carrying observed and synthetic GNSS displacements."""

    statics_syn = Dict.T(
        optional=True,
        help='Synthetic gnss surface displacements')
    statics_obs = Dict.T(
        optional=True,
        help='Observed gnss surface displacements')
示例#22
0
 class Y(Object):
     # Dict property keyed and valued by X objects, with a one-entry default.
     xs = Dict.T(
         X.T(),
         X.T(),
         default={X(a=2): X(a=3)})
示例#23
0
class PolygonMaskConfig(PluginConfig):
    """Plugin configuration holding mask polygons and an on/off switch."""

    polygons = Dict.T(
        optional=True,
        default={})
    applied = Bool.T(
        default=True)
示例#24
0
    def testOptionalDefault(self):
        """Check ``optional``/``default`` handling for several guts types.

        Each entry of ``data`` is ``(name, type, values, expected)``. For
        every value an object ``A`` with a single property ``v`` of that type
        is created (``None`` means: do not pass ``v``). The expected entry is
        either the value ``v`` must end up with, ``'aerr'`` (construction must
        raise ``ArgumentError``) or ``'err'`` (validation must raise
        ``ValidationError``). Valid objects are also dumped and reloaded and
        the round-tripped value is compared as well.
        """

        from pyrocko.guts_array import Array, array_equal
        import numpy as num
        assert_ae = num.testing.assert_almost_equal

        def array_equal_noneaware(a, b):
            if a is None:
                return b is None
            elif b is None:
                return a is None
            else:
                return array_equal(a, b)

        # NOTE: dtype=int is used instead of the alias num.int, which was
        # removed from NumPy; both denote the same dtype.
        data = [
            ('a', Int.T(),
                [None, 0, 1, 2],
                ['aerr', 0, 1, 2]),
            ('b', Int.T(optional=True),
                [None, 0, 1, 2],
                [None, 0, 1, 2]),
            ('c', Int.T(default=1),
                [None, 0, 1, 2],
                [1, 0, 1, 2]),
            ('d', Int.T(default=1, optional=True),
                [None, 0, 1, 2],
                [1, 0, 1, 2]),
            ('e', List.T(Int.T()),
                [None, [], [1], [2]],
                [[], [], [1], [2]]),
            ('f', List.T(Int.T(), optional=True),
                [None, [], [1], [2]],
                [None, [], [1], [2]]),
            ('g', List.T(Int.T(), default=[1]), [
                None, [], [1], [2]],
                [[1], [], [1], [2]]),
            ('h', List.T(Int.T(), default=[1], optional=True),
                [None, [], [1], [2]],
                [[1], [], [1], [2]]),
            ('i', Tuple.T(2, Int.T()),
                [None, (1, 2)],
                ['err', (1, 2)]),
            ('j', Tuple.T(2, Int.T(), optional=True),
                [None, (1, 2)],
                [None, (1, 2)]),
            ('k', Tuple.T(2, Int.T(), default=(1, 2)),
                [None, (1, 2), (3, 4)],
                [(1, 2), (1, 2), (3, 4)]),
            ('l', Tuple.T(2, Int.T(), default=(1, 2), optional=True),
                [None, (1, 2), (3, 4)],
                [(1, 2), (1, 2), (3, 4)]),
            ('i2', Tuple.T(None, Int.T()),
                [None, (1, 2)],
                [(), (1, 2)]),
            ('j2', Tuple.T(None, Int.T(), optional=True),
                [None, (), (3, 4)],
                [None, (), (3, 4)]),
            ('k2', Tuple.T(None, Int.T(), default=(1,)),
                [None, (), (3, 4)],
                [(1,), (), (3, 4)]),
            ('l2', Tuple.T(None, Int.T(), default=(1,), optional=True),
                [None, (), (3, 4)],
                [(1,), (), (3, 4)]),
            ('m', Array.T(shape=(None,), dtype=int, serialize_as='list'),
                [num.arange(0), num.arange(2)],
                [num.arange(0), num.arange(2)]),
            ('n', Array.T(shape=(None,), dtype=int, serialize_as='list',
                          optional=True),
                [None, num.arange(0), num.arange(2)],
                [None, num.arange(0), num.arange(2)]),
            ('o', Array.T(shape=(None,), dtype=int, serialize_as='list',
                          default=num.arange(2)),
                [None, num.arange(0), num.arange(2), num.arange(3)],
                [num.arange(2), num.arange(0), num.arange(2), num.arange(3)]),
            ('p', Array.T(shape=(None,), dtype=int, serialize_as='list',
                          default=num.arange(2), optional=True),
                [None, num.arange(0), num.arange(2), num.arange(3)],
                [num.arange(2), num.arange(0), num.arange(2), num.arange(3)]),
            ('q', Dict.T(String.T(), Int.T()),
                [None, {}, {'a': 1}],
                [{}, {}, {'a': 1}]),
            ('r', Dict.T(String.T(), Int.T(), optional=True),
                [None, {}, {'a': 1}],
                [None, {}, {'a': 1}]),
            ('s', Dict.T(String.T(), Int.T(), default={'a': 1}),
                [None, {}, {'a': 1}],
                [{'a': 1}, {}, {'a': 1}]),
            ('t', Dict.T(String.T(), Int.T(), default={'a': 1}, optional=True),
                [None, {}, {'a': 1}],
                [{'a': 1}, {}, {'a': 1}]),
        ]

        for k, t, vals, exp, in data:
            # Number of keyword arguments seen by the most recent A.__init__
            # call (including the one made while loading the dumped object).
            last = [None]

            class A(Object):
                def __init__(self, **kwargs):
                    last[0] = len(kwargs)
                    Object.__init__(self, **kwargs)

                v = t

            A.T.class_signature()

            for v, e in zip(vals, exp):
                if isinstance(e, str) and e == 'aerr':
                    with self.assertRaises(ArgumentError):
                        if v is not None:
                            a1 = A(v=v)
                        else:
                            a1 = A()

                    continue
                else:
                    if v is not None:
                        a1 = A(v=v)
                    else:
                        a1 = A()

                if isinstance(e, str) and e == 'err':
                    with self.assertRaises(ValidationError):
                        a1.validate()
                else:
                    a1.validate()
                    a2 = load_string(dump(a1))
                    if isinstance(e, num.ndarray):
                        assert last[0] == int(
                            not (array_equal_noneaware(t.default(), a1.v)
                                 and t.optional))
                        assert_ae(a1.v, e)
                        # Also verify the dumped-and-reloaded object, as the
                        # non-array branch below does.
                        assert_ae(a2.v, e)
                    else:
                        assert last[0] == int(
                            not(t.default() == a1.v and t.optional))
                        self.assertEqual(a1.v, e)
                        self.assertEqual(a2.v, e)
示例#25
0
 class A(Object):
     # Dict property with Int keys and Float values.
     d = Dict.T(
         Int.T(),
         Float.T())
示例#26
0
class Map(Object):
    """Map of a circular region around a centre point, rendered with GMT.

    The declarations below are the configurable properties; the GMT state
    itself is built lazily by the methods of this class.
    """

    # Centre of the map and radius of the mapped region. The radius is
    # compared against multiples of ``km`` in the code, so it is presumably
    # given in metres -- TODO confirm.
    lat = Float.T(optional=True)
    lon = Float.T(optional=True)
    radius = Float.T(optional=True)
    # Page size; multiplied by ``gmtpy.cm`` when the paper size is set.
    width = Float.T(default=20.)
    height = Float.T(default=14.)
    # 1 value: all sides, 2 values: (left/right, top/bottom), 4 values:
    # (left, right, top, bottom); an empty list selects the built-in default.
    margins = List.T(Float.T())
    illuminate = Bool.T(default=True)
    # Scales the minimum area of coastline features that are still drawn.
    skip_feature_factor = Float.T(default=0.02)
    show_grid = Bool.T(default=False)
    show_topo = Bool.T(default=True)
    show_topo_scale = Bool.T(default=False)
    show_center_mark = Bool.T(default=False)
    show_rivers = Bool.T(default=True)
    illuminate_factor_land = Float.T(default=0.5)
    illuminate_factor_ocean = Float.T(default=0.25)
    # Fill colours used where no topography is drawn for ocean / land.
    color_wet = Tuple.T(3, Int.T(), default=(216, 242, 254))
    color_dry = Tuple.T(3, Int.T(), default=(172, 208, 165))
    topo_resolution_min = Float.T(
        default=40., help='minimum resolution of topography [dpi]')
    topo_resolution_max = Float.T(
        default=200., help='maximum resolution of topography [dpi]')
    replace_topo_color_only = FloatTile.T(
        optional=True,
        help='replace topo color while keeping topographic shading')
    # Names of the colour tables used for ocean / land topography.
    topo_cpt_wet = String.T(default='light_sea')
    topo_cpt_dry = String.T(default='light_land')
    # If unset, the axes layout is chosen from the sign of ``lat``.
    axes_layout = String.T(optional=True)
    custom_cities = List.T(City.T())
    # Extra GMT settings; keys are upper-cased before being applied.
    gmt_config = Dict.T(String.T(), String.T())
    # Optional text printed at the lower right of the plot.
    comment = String.T(optional=True)

    def __init__(self, **kwargs):
        """Initialise the guts properties and reset all lazily built state.

        Attributes set to ``None`` here are filled in on demand by the
        ``_setup*`` / ``_draw_background`` methods.
        """
        Object.__init__(self, **kwargs)
        self._gmt = None
        self._scaler = None
        self._widget = None
        # _layout and _pxyr are tested against None by the ``layout`` and
        # ``pxyr`` properties, so they must exist before the first access.
        self._layout = None
        self._corners = None
        self._wesn = None
        self._minarea = None
        self._coastline_resolution = None
        self._rivers = None
        self._dems = None
        self._have_topo_land = None
        self._have_topo_ocean = None
        self._jxyr = None
        self._pxyr = None
        self._prep_topo_have = None
        # Read by draw_axes() / draw_labels(); reset again by _setup_gmt().
        self._have_drawn_axes = False
        self._have_drawn_labels = False
        self._labels = []

    def save(self, outpath, resolution=75., oversample=2., size=None,
             width=None, height=None):
        '''Finish the plot and write it to *outpath*.

        Pending labels, the axes and (if both ``show_topo`` and
        ``show_topo_scale`` are set) the topography scale are drawn first.

        The output format follows the filename extension. ``'.eps'`` and
        ``'.ps'`` are written directly by GMT. For ``'.pdf'`` the GMT output
        is passed through ``gmtpy-epstopdf``. Any other extension is produced
        from such a PDF: it is rasterized to PNG with ``pdftocairo`` at a
        resolution oversampled by *oversample*, then downsampled and converted
        to the target format with ``convert``. The size of the raster image is
        controlled either by *resolution* (DPI) or by *width*, *height* or
        *size*; *size* fits the image into a square of that side length.'''

        # Accessing the property makes sure GMT and the background exist.
        self.gmt
        self.draw_labels()
        self.draw_axes()
        if self.show_topo and self.show_topo_scale:
            self._draw_topo_scale()

        gmt = self._gmt

        # GMT writes EPS/PS natively; everything else goes through PDF.
        suffix = '.pdf'
        for native in ('.eps', '.ps'):
            if outpath.endswith(native):
                suffix = native
                break

        tmppath = gmt.tempfilename() + suffix
        gmt.save(tmppath)

        if outpath.endswith(('.eps', '.ps', '.pdf')):
            shutil.copy(tmppath, outpath)
            return

        convert_graph(
            tmppath, outpath,
            resolution=resolution,
            oversample=oversample,
            size=size,
            width=width,
            height=height)

    @property
    def scaler(self):
        """Scaler of the map; the geometry is set up on first access."""
        if self._scaler is not None:
            return self._scaler

        self._setup_geometry()
        return self._scaler

    @property
    def wesn(self):
        """(west, east, south, north) tuple; geometry is set up on demand."""
        if self._wesn is not None:
            return self._wesn

        self._setup_geometry()
        return self._wesn

    @property
    def widget(self):
        """Layout widget of the plot; runs the full setup on first access."""
        if self._widget is not None:
            return self._widget

        self._setup()
        return self._widget

    @property
    def layout(self):
        """GMT layout of the plot; runs the full setup on first access."""
        # ``_layout`` is only ever assigned by _setup_gmt(). Using getattr
        # lets the first access trigger the setup instead of failing with
        # AttributeError because the attribute does not exist yet.
        if getattr(self, '_layout', None) is None:
            self._setup()

        return self._layout

    @property
    def jxyr(self):
        """Projection/region arguments (``JXY`` + ``R``), set up on demand."""
        if self._jxyr is not None:
            return self._jxyr

        self._setup()
        return self._jxyr

    @property
    def pxyr(self):
        """Page-coordinate arguments (``PXY`` + ``R``), set up on demand."""
        # ``_pxyr`` is only ever assigned by _setup_gmt(). Using getattr
        # lets the first access trigger the setup instead of failing with
        # AttributeError because the attribute does not exist yet.
        if getattr(self, '_pxyr', None) is None:
            self._setup()

        return self._pxyr

    @property
    def gmt(self):
        """GMT instance with the map background already drawn.

        The first access runs the setup and, as long as the ocean topography
        flag is still unset, draws the background.
        """
        needs_setup = self._gmt is None
        if needs_setup:
            self._setup()

        background_pending = self._have_topo_ocean is None
        if background_pending:
            self._draw_background()

        return self._gmt

    def _setup(self):
        """Prepare geometry (if no widget exists), level of detail and GMT."""
        widget_missing = not self._widget
        if widget_missing:
            self._setup_geometry()

        self._setup_lod()
        self._setup_gmt()

    def _setup_geometry(self):
        """Derive the map region and the scaler from centre, radius and page.

        Sets ``_corners``, ``_wesn`` and ``_scaler``.
        """
        ml, mr, mt, mb = self._expand_margins()
        wpage = self.width - (ml + mr)
        hpage = self.height - (mt + mb)

        # Start from a square of twice the radius and stretch the longer
        # page side so that the region matches the page aspect ratio.
        wreg = hreg = self.radius * 2.0
        if wpage >= hpage:
            wreg *= wpage / hpage
        else:
            hreg *= hpage / wpage

        self._corners = corners(self.lon, self.lat, wreg, hreg)
        west, east, south, north = extent(self.lon, self.lat, wreg, hreg, 10)

        ranges = ((west, east), (south, north), (-6000., 4500.))

        xax = gmtpy.Ax(mode='min-max', approx_ticks=4.)
        yax = gmtpy.Ax(mode='min-max', approx_ticks=4.)
        zax = gmtpy.Ax(
            mode='min-max',
            inc=1000.,
            label='Height',
            scaled_unit='km',
            scaled_unit_factor=0.001)

        scaler = gmtpy.ScaleGuru(data_tuples=[ranges], axes=(xax, yax, zax))

        # The scaler may adjust the limits; use its values from here on.
        par = scaler.get_params()
        self._wesn = par['xmin'], par['xmax'], par['ymin'], par['ymax']
        self._scaler = scaler

    def _setup_lod(self):
        """Choose the level of detail for coastlines, rivers and topography.

        Sets ``_coastline_resolution``, ``_rivers``, ``_minarea`` and the
        topography datasets to try for ocean and land (``_dems``), and resets
        the cache of prepared topography grids.
        """
        # Values are unused; the unpacking fails early if no region is set.
        _w, _e, _s, _n = self._wesn

        large_region = self.radius > 1500. * km
        self._minarea = (self.skip_feature_factor * self.radius / km)**2
        self._coastline_resolution = 'i' if large_region else 'f'
        self._rivers = not large_region

        self._prep_topo_have = {}
        self._dems = {}

        cm2inch = gmtpy.cm / gmtpy.inch
        span = 2.0 * self.radius * m2d
        height_inch = self.height * cm2inch

        dmin = span / (self.topo_resolution_max * height_inch)
        dmax = span / (self.topo_resolution_min * height_inch)

        for kind in ('ocean', 'land'):
            names = topo.select_dem_names(kind, dmin, dmax, self._wesn)
            self._dems[kind] = names
            if names:
                logger.debug('using topography dataset %s for %s' %
                             (','.join(names), kind))

    def _expand_margins(self):
        """Expand ``self.margins`` to ``(left, right, top, bottom)``.

        One value applies to all sides, two values are (left/right,
        top/bottom) and four values are taken as given. Any other length
        (none, three, or more than four values) selects the default margin
        of 2.0 on every side. Three values previously fell through all
        branches and raised ``UnboundLocalError``.
        """
        margins = self.margins
        nmargins = len(margins)

        if nmargins == 1:
            ml = mr = mt = mb = margins[0]
        elif nmargins == 2:
            ml = mr = margins[0]
            mt = mb = margins[1]
        elif nmargins == 4:
            ml, mr, mt, mb = margins
        else:
            ml = mr = mt = mb = 2.0

        return ml, mr, mt, mb

    def _setup_gmt(self):
        """Create the GMT instance, its layout and the projection arguments.

        Sets ``_gmt``, ``_layout``, ``_widget``, ``_jxyr`` and ``_pxyr`` and
        resets the "already drawn" flags for axes and labels. Requires
        ``_scaler`` and ``_corners`` to be set (see ``_setup_geometry``).
        """
        w, h = self.width, self.height
        scaler = self._scaler

        gmtconf = dict(TICK_PEN='1.25p',
                       TICK_LENGTH='0.2c',
                       ANNOT_FONT_PRIMARY='1',
                       ANNOT_FONT_SIZE_PRIMARY='12p',
                       LABEL_FONT='1',
                       LABEL_FONT_SIZE='12p',
                       CHAR_ENCODING='ISOLatin1+',
                       BASEMAP_TYPE='fancy',
                       PLOT_DEGREE_FORMAT='D',
                       PAPER_MEDIA='Custom_%ix%i' %
                       (w * gmtpy.cm, h * gmtpy.cm),
                       GRID_PEN_PRIMARY='thinnest/0/50/0',
                       DOTS_PR_INCH='1200',
                       OBLIQUE_ANNOTATION='6')

        # User settings override the defaults above. dict.items() is used
        # because iteritems() only exists on Python 2.
        gmtconf.update(
            (k.upper(), v) for (k, v) in self.gmt_config.items())

        gmt = gmtpy.GMT(config=gmtconf)

        layout = gmt.default_layout()

        layout.set_fixed_margins(*[x * cm for x in self._expand_margins()])

        widget = layout.get_widget()
        # Keep the widget's original projection template under 'P' (used for
        # page coordinates) before replacing 'J' with the map projection
        # centred on (lon, lat).
        widget['P'] = widget['J']
        widget['J'] = ('-JA%g/%g' % (self.lon, self.lat)) + '/%(width_m)gm'
        # Region given by its corners (trailing 'r').
        scaler['R'] = '-R%g/%g/%g/%gr' % self._corners

        aspect = gmtpy.aspect_for_projection(*(widget.J() + scaler.R()))
        widget.set_aspect(aspect)

        self._gmt = gmt
        self._layout = layout
        self._widget = widget
        self._jxyr = self._widget.JXY() + self._scaler.R()
        self._pxyr = self._widget.PXY() + [
            '-R%g/%g/%g/%g' % (0, widget.width(), 0, widget.height())
        ]
        self._have_drawn_axes = False
        self._have_drawn_labels = False

    def _draw_background(self):
        """Draw topography (if enabled) followed by the base features."""
        # Start from "no topography"; _draw_topo() updates the flags.
        self._have_topo_land = self._have_topo_ocean = False

        if self.show_topo:
            self._have_topo = self._draw_topo()

        self._draw_basefeatures()

    def _get_topo_tile(self, k):
        """Return ``(tile, demname)`` of the first dataset providing a tile.

        The datasets selected for *k* are tried in order.

        :raises NoTopo: if no usable tile was found.
        """
        tile, demname = None, None
        for candidate in self._dems[k]:
            demname = candidate
            tile = topo.get(candidate, self._wesn)
            if tile is not None:
                break

        if tile:
            return tile, demname

        raise NoTopo()

    def _prep_topo(self, k):
        """Prepare the topography grid file for *k* and return it.

        Returns a tuple ``(grdfile, ilumargs)``: the path of the grid file to
        plot and a list of extra arguments for illumination (empty if
        ``illuminate`` is off). Results are cached per dataset name in
        ``_prep_topo_have``. ``NoTopo`` from ``_get_topo_tile`` propagates.
        """
        gmt = self._gmt
        t, demname = self._get_topo_tile(k)

        if demname not in self._prep_topo_have:

            # Write the tile to a temporary grid file.
            grdfile = gmt.tempfilename()
            gmtpy.savegrd(t.x(),
                          t.y(),
                          t.data,
                          filename=grdfile,
                          naming='lonlat')

            if self.illuminate:
                if k == 'ocean':
                    factor = self.illuminate_factor_ocean
                else:
                    factor = self.illuminate_factor_land

                # Gradient grid of the topography, handed on as '-I<file>'.
                ilumfn = gmt.tempfilename()
                gmt.grdgradient(grdfile,
                                N='e%g' % factor,
                                A=-45,
                                G=ilumfn,
                                out_discard=True)

                ilumargs = ['-I%s' % ilumfn]
            else:
                ilumargs = []

            if self.replace_topo_color_only:
                # Use the replacement tile for colouring. NOTE(review): the
                # grdsample/grdmath sequence below reuses both temporary
                # files as input and output; it presumably resamples the
                # replacement data onto the topography grid -- confirm
                # against the GMT documentation before changing the order.
                t2 = self.replace_topo_color_only
                grdfile2 = gmt.tempfilename()

                gmtpy.savegrd(t2.x(),
                              t2.y(),
                              t2.data,
                              filename=grdfile2,
                              naming='lonlat')

                gmt.grdsample(grdfile2,
                              G=grdfile,
                              Q='l',
                              I='%g/%g' % (t.dx, t.dy),
                              R=grdfile,
                              out_discard=True)

                gmt.grdmath(grdfile,
                            '0.0',
                            'AND',
                            '=',
                            grdfile2,
                            out_discard=True)

                grdfile = grdfile2

            self._prep_topo_have[demname] = grdfile, ilumargs

        return self._prep_topo_have[demname]

    def _draw_topo(self):
        """Draw ocean and land topography, each clipped to its own area.

        For each of the two a coastline clip path is set up, the prepared
        grid is drawn with its colour table and the clip path is released.
        ``_have_topo_ocean`` / ``_have_topo_land`` record whether drawing
        succeeded; they are set to False when ``NoTopo`` is raised.
        """
        gmt = self._gmt
        cres = self._coastline_resolution
        minarea = self._minarea
        jxyr = self._widget.JXY() + self._scaler.R()

        # (dataset kind, pscoast clip option, colour table attribute, flag)
        layers = (
            ('ocean', 'S', 'topo_cpt_wet', '_have_topo_ocean'),
            ('land', 'G', 'topo_cpt_dry', '_have_topo_land'))

        for kind, clip_option, cpt_attr, flag in layers:
            try:
                grdfile, ilumargs = self._prep_topo(kind)
                gmt.pscoast(
                    *jxyr, **{'D': cres, clip_option: 'c', 'A': minarea})
                gmt.grdimage(
                    grdfile,
                    C=topo.cpt(getattr(self, cpt_attr)),
                    *(ilumargs + jxyr))
                gmt.pscoast(Q=True, *jxyr)
                setattr(self, flag, True)
            except NoTopo:
                setattr(self, flag, False)

    def _draw_topo_scale(self, label='Elevation [km]'):
        """Draw a colour scale for the combined ocean/land colour tables.

        The levels of the merged colour table are divided by ``km`` before
        it is written to a temporary file and handed to ``psscale``.
        """
        cpt_dry = read_cpt(topo.cpt(self.topo_cpt_dry))
        cpt_wet = read_cpt(topo.cpt(self.topo_cpt_wet))
        merged = cpt_merge_wet_dry(cpt_wet, cpt_dry)
        for level in merged.levels:
            level.vmin /= km
            level.vmax /= km

        cpt_path = self.gmt.tempfilename()
        write_cpt(merged, cpt_path)

        (w, _h), (xo, yo) = self.widget.get_size()
        # Horizontal bar as wide as the widget, centred below it.
        position = '%gp/%gp/%gp/%gph' % (
            xo + 0.5 * w, yo - 2.0 * gmtpy.cm, w, 0.5 * gmtpy.cm)

        self.gmt.psscale(
            D=position,
            C=cpt_path,
            B='1:%s:' % label)

    def _draw_basefeatures(self):
        """Draw coastlines, optionally rivers, and plain land/ocean fills.

        Land and ocean are filled with ``color_dry`` / ``color_wet`` only
        where no topography has been drawn for them.
        """
        gmt = self._gmt
        color_wet = self.color_wet
        color_dry = self.color_dry

        if self.show_rivers and self._rivers:
            river_args = ['-Ir/0.25p,%s' % gmtpy.color(self.color_wet)]
        else:
            river_args = []

        fill = {}
        if not self._have_topo_land:
            fill['G'] = color_dry

        if not self._have_topo_ocean:
            fill['S'] = color_wet

        # Coastline pen: a darkened variant of the land colour.
        pen = 'thinnest,%s' % gmtpy.color(darken(gmtpy.color_tup(color_dry)))

        gmt.pscoast(
            D=self._coastline_resolution,
            W=pen,
            A=self._minarea,
            *(river_args + self._jxyr),
            **fill)

    def _draw_axes(self):
        """Draw centre mark (optional), map frame with scale bar and comment.

        Without an explicit ``axes_layout`` the layout is ``'WSen'`` for
        positive latitudes and ``'WseN'`` otherwise.
        """
        gmt = self._gmt
        scaler = self._scaler
        widget = self._widget

        axes_layout = self.axes_layout
        if axes_layout is None:
            axes_layout = 'WSen' if self.lat > 0.0 else 'WseN'

        scale_km = gmtpy.nice_value(self.radius / 5.) / 1000.

        if self.show_center_mark:
            gmt.psxy(
                in_rows=[[self.lon, self.lat]],
                S='c20p',
                W='2p,black',
                *self._jxyr)

        # With the grid enabled the increment is repeated after 'g'.
        if self.show_grid:
            btmpl = ('%(xinc)gg%(xinc)g:%(xlabel)s:/'
                     '%(yinc)gg%(yinc)g:%(ylabel)s:')
        else:
            btmpl = '%(xinc)g:%(xlabel)s:/%(yinc)g:%(ylabel)s:'

        frame = (btmpl % scaler.get_params()) + axes_layout
        scale_bar = 'x%gp/%gp/%g/%g/%gk' % (
            6. / 7 * widget.width(), widget.height() / 7.,
            self.lon, self.lat, scale_km)

        gmt.psbasemap(B=frame, L=scale_bar, *self._jxyr)

        if not self.comment:
            return

        fontsize = self.gmt.to_points(
            self.gmt.gmt_config['LABEL_FONT_SIZE'])

        # Values are not used below.
        _, _east, _south, _ = self._wesn

        # Comment is anchored at the lower right corner of the unit square.
        gmt.pstext(
            in_rows=[[1, 0, fontsize, 0, 0, 'BR', self.comment]],
            N=True,
            R=(0, 1, 0, 1),
            D='%gp/%gp' % (-fontsize * 0.2, fontsize * 0.3),
            *widget.PXY())

    def draw_axes(self):
        """Draw the axes once; further calls do nothing."""
        # Accessing the property runs the lazy GMT setup, which creates the
        # GMT instance and the ``_have_drawn_axes`` flag used below. Without
        # it, calling this method on a fresh map fails (same pattern as in
        # draw_labels()).
        self.gmt
        if not self._have_drawn_axes:
            self._draw_axes()
            self._have_drawn_axes = True

    def _have_coastlines(self):
        """Return True if any coastline point lies inside the map region.

        The coastlines are dumped to a temporary file with ``pscoast`` and
        the file is scanned until a point inside ``_wesn`` is found.
        """
        gmt = self._gmt
        checkfile = gmt.tempfilename()

        gmt.pscoast(
            M=True,
            D=self._coastline_resolution,
            W='thinnest/black',
            A=self._minarea,
            out_filename=checkfile,
            *self._jxyr)

        with open(checkfile, 'r') as f:
            for raw in f:
                entry = raw.strip()
                # Skip blank lines, comments and segment headers.
                if entry == '' or entry.startswith(('#', '>')):
                    continue

                plon, plat = [float(tok) for tok in entry.split()]
                if point_in_region((plon, plat), self._wesn):
                    return True

        return False

    def project(self, lats, lons):
        """Project geographic coordinates with GMT ``mapproject``.

        *lats* and *lons* may be sequences or two single floats; in the
        latter case two scalars are returned instead of arrays.

        :returns: tuple ``(xs, ys)`` as read from the ``mapproject`` output
        """
        onepoint = isinstance(lats, float) and isinstance(lons, float)
        if onepoint:
            lats, lons = [lats], [lons]

        j, _, _, r = self.jxyr
        # Offset of the widget; the values are not used below.
        _xo, _yo = self.widget.get_size()[1]

        buf = StringIO()
        self.gmt.mapproject(
            j, r, in_columns=(lons, lats), out_stream=buf, D='p')
        buf.seek(0)

        xs, ys = num.loadtxt(buf, ndmin=2).T
        if onepoint:
            return xs[0], ys[0]

        return xs, ys

    def _draw_labels(self):
        """Place all queued labels on the map, avoiding overlaps.

        For every label the text size is measured by rendering it with GMT
        and asking Ghostscript for its bounding box. Each label can be
        anchored in four ways relative to its point. An anchor is usable if
        the label stays inside the plot and covers no other label point. A
        simple iterative search then picks anchors minimising the number of
        mutually overlapping labels; labels that still overlap are dropped.

        Compatibility notes: ``dtype=float`` replaces the removed NumPy
        alias ``num.float``, ``range`` replaces ``xrange``, and the bounding
        box prefix is a bytes literal because ``Popen.communicate()``
        returns bytes on Python 3 (on Python 2 nothing changes).
        """
        if not self._labels:
            return

        fontsize = self.gmt.to_points(
            self.gmt.gmt_config['LABEL_FONT_SIZE'])

        n = len(self._labels)

        lons, lats, texts, sx, sy, styles = zip(*self._labels)

        sx = num.array(sx, dtype=float)
        sy = num.array(sy, dtype=float)

        # Project the label points.
        j, _, _, r = self.jxyr
        f = StringIO()
        self.gmt.mapproject(j,
                            r,
                            in_columns=(lons, lats),
                            out_stream=f,
                            D='p')
        f.seek(0)
        data = num.loadtxt(f, ndmin=2)
        xs, ys = data.T
        dxs = num.zeros(n)
        dys = num.zeros(n)

        # Width and height of the map frame from a scratch plot.
        g = gmtpy.GMT()
        g.psbasemap('-G0', finish=True, *(j, r))
        left, bottom, right, top = g.bbox()
        h = (top - bottom)
        w = (right - left)

        # Measure the rendered size of every label text.
        for i in range(n):
            g = gmtpy.GMT()
            g.pstext(in_rows=[[0, 0, fontsize, 0, 1, 'BL', texts[i]]],
                     finish=True,
                     R=(0, 1, 0, 1),
                     J='x10p',
                     N=True,
                     **styles[i])

            fn = g.tempfilename()
            g.save(fn)

            (_, stderr) = Popen([
                'gs', '-q', '-dNOPAUSE', '-dBATCH', '-r720',
                '-sDEVICE=bbox', fn
            ],
                                stderr=PIPE).communicate()

            dx, dy = None, None
            for line in stderr.splitlines():
                if line.startswith(b'%%HiResBoundingBox:'):
                    left, bottom, right, top = [
                        float(x) for x in line.split()[-4:]]
                    dx, dy = right - left, top - bottom

            dxs[i] = dx
            dys[i] = dy

        la = num.logical_and

        # Per anchor: does the label stay inside the frame? The order
        # matches ``arects`` and ``anchor_strs`` below.
        anchors_ok = (
            la(xs + sx + dxs < w, ys + sy + dys < h),
            la(xs - sx - dxs > 0., ys - sy - dys > 0.),
            la(xs + sx + dxs < w, ys - sy - dys > 0.),
            la(xs - sx - dxs > 0., ys + sy + dys < h),
        )

        # Per anchor: rectangle (xmin, ymin, xmax, ymax) covered by label.
        arects = [(xs, ys, xs + sx + dxs, ys + sy + dys),
                  (xs - sx - dxs, ys - sy - dys, xs, ys),
                  (xs, ys - sy - dys, xs + sx + dxs, ys),
                  (xs - sx - dxs, ys, xs, ys + sy + dys)]

        def no_points_in_rect(xs, ys, xmin, ymin, xmax, ymax):
            xx = not num.any(
                la(la(xmin < xs, xs < xmax), la(ymin < ys, ys < ymax)))
            return xx

        # An anchor is also unusable if the label would cover a point.
        for i in range(n):
            for ianch in range(4):
                anchors_ok[ianch][i] &= no_points_in_rect(
                    xs, ys, *[xxx[i] for xxx in arects[ianch]])

        anchor_choices = []
        anchor_take = []
        for i in range(n):
            choices = [
                ianch for ianch in range(4) if anchors_ok[ianch][i]
            ]
            anchor_choices.append(choices)
            if choices:
                anchor_take.append(choices[0])
            else:
                anchor_take.append(None)

        def roverlaps(a, b):
            return (a[0] < b[2] and b[0] < a[2] and a[1] < b[3]
                    and b[1] < a[3])

        def cost(anchor_take):
            # Number of ordered pairs of placed labels that overlap.
            noverlaps = 0
            for i in range(n):
                for k in range(n):
                    if i != k:
                        i_take = anchor_take[i]
                        k_take = anchor_take[k]
                        if i_take is None or k_take is None:
                            continue
                        r_i = [xxx[i] for xxx in arects[i_take]]
                        r_k = [xxx[k] for xxx in arects[k_take]]
                        if roverlaps(r_i, r_k):
                            noverlaps += 1

            return noverlaps

        # Greedy improvement: switch single anchors while that reduces the
        # cost, for at most 30 sweeps.
        cur_cost = cost(anchor_take)
        imax = 30
        while cur_cost != 0 and imax > 0:
            for i in range(n):
                for choice in anchor_choices[i]:
                    anchor_take_new = list(anchor_take)
                    anchor_take_new[i] = choice
                    new_cost = cost(anchor_take_new)
                    if new_cost < cur_cost:
                        anchor_take = anchor_take_new
                        cur_cost = new_cost

            imax -= 1

        # Drop labels one at a time until nothing overlaps any more.
        while cur_cost != 0:
            for i in range(n):
                anchor_take_new = list(anchor_take)
                anchor_take_new[i] = None
                new_cost = cost(anchor_take_new)
                if new_cost < cur_cost:
                    anchor_take = anchor_take_new
                    cur_cost = new_cost
                    break

        anchor_strs = ['BL', 'TR', 'TL', 'BR']

        for i in range(n):
            ianchor = anchor_take[i]
            if ianchor is not None:
                anchor = anchor_strs[ianchor]
                # Shift the text away from the point, towards the label.
                yoff = [-sy[i], sy[i]][anchor[0] == 'B']
                xoff = [-sx[i], sx[i]][anchor[1] == 'L']
                row = (lons[i], lats[i], fontsize, 0, 1, anchor, texts[i])
                self.gmt.pstext(in_rows=[row],
                                D='%gp/%gp' % (xoff, yoff),
                                *self.jxyr,
                                **styles[i])

    def draw_labels(self):
        """Draw the queued labels once; further calls do nothing."""
        # Property access only: makes sure GMT has been set up.
        self.gmt
        if self._have_drawn_labels:
            return

        self._draw_labels()
        self._have_drawn_labels = True

    def add_label(self, lat, lon, text, offset_x=5., offset_y=5., style=None):
        """Queue a text label at (*lat*, *lon*) for later drawing.

        *offset_x* and *offset_y* give the distance between the point and
        the label; *style* holds extra keyword arguments for ``pstext``.
        """
        # A fresh dict per call: the style is stored with the label, so a
        # shared mutable default would be shared by all labels.
        if style is None:
            style = {}

        self._labels.append((lon, lat, text, offset_x, offset_y, style))

    def cities_in_region(self):
        """Return geonames and custom cities, sorted by rising population.

        The geonames cities are queried for the map region without a
        population limit; ``custom_cities`` are appended unfiltered.
        """
        from pyrocko import geonames

        found = geonames.get_cities_region(region=self.wesn, minpop=0)
        found.extend(self.custom_cities)
        found.sort(key=lambda city: city.population)
        return found

    def draw_cities(
        self,
        exact=None,
        include=[],
        exclude=['City of London'],  # duplicate entry in geonames
        nmax_soft=10,
        psxy_style=dict(S='s5p', G='black')):
        """Plot city markers and queue a label for every plotted city.

        With *exact* given, only cities whose name is in *exact* are used.
        Otherwise cities named in *exclude* are removed and the minimum
        population is raised step by step until at most *nmax_soft* cities
        are left or too few would remain. The population threshold finally
        applied is stored in ``_cities_minpop`` (``None`` with *exact*).
        *include* is not used by this method.
        """
        cities = self.cities_in_region()

        if exact is not None:
            cities = [c for c in cities if c.name in exact]
            minpop = None
        else:
            cities = [c for c in cities if c.name not in exclude]
            minpop = 10**3
            thresholds = [1e3, 3e3, 1e4, 3e4, 1e5, 3e5, 1e6, 3e6, 1e7]
            for threshold in thresholds:
                larger = [c for c in cities if c.population > threshold]

                # Stop before the selection becomes empty, or too sparse
                # while the current one is still reasonably small.
                too_sparse = (
                    len(larger) < 3 and len(cities) < nmax_soft * 2)
                if not larger or too_sparse:
                    break

                cities = larger
                minpop = threshold
                if len(cities) <= nmax_soft:
                    break

        if cities:
            lats = [c.lat for c in cities]
            lons = [c.lon for c in cities]

            self.gmt.psxy(in_columns=(lons, lats), *self.jxyr, **psxy_style)

            for c in cities:
                try:
                    text = c.name.encode('iso-8859-1')
                except UnicodeEncodeError:
                    # Name not representable in Latin-1.
                    text = c.asciiname

                self.add_label(c.lat, c.lon, text)

        self._cities_minpop = minpop
示例#27
0
class Map(Object):
    '''
    Configurable map figure rendered through GMT (via ``gmtpy``).

    The declared attributes describe the map center and extent, the page
    layout and which features are drawn. The GMT instance and the derived
    geometry are created lazily when the corresponding properties are first
    accessed.
    '''

    # Map center; used as projection center and as center of the region.
    lat = Float.T(optional=True)
    lon = Float.T(optional=True)
    # Half extent of the region along the shorter page dimension. It is
    # compared against multiples of ``km`` in the code, so presumably given
    # in meters -- TODO confirm.
    radius = Float.T(optional=True)
    # Page size; multiplied with ``gmtpy.cm`` when handed to GMT.
    width = Float.T(default=20.)
    height = Float.T(default=14.)
    # Page margins, interpreted by ``_expand_margins``.
    margins = List.T(Float.T())
    # Whether topography gets gradient shading.
    illuminate = Bool.T(default=True)
    # Scales the minimum feature area passed to ``pscoast`` (option ``A``).
    skip_feature_factor = Float.T(default=0.02)
    # Feature switches.
    show_grid = Bool.T(default=False)
    show_topo = Bool.T(default=True)
    show_topo_scale = Bool.T(default=False)
    show_center_mark = Bool.T(default=False)
    show_rivers = Bool.T(default=True)
    show_plates = Bool.T(default=False)
    # Normalization factors used for the shading of land and ocean tiles.
    illuminate_factor_land = Float.T(default=0.5)
    illuminate_factor_ocean = Float.T(default=0.25)
    # Fill colors used where no topography is drawn.
    color_wet = Tuple.T(3, Int.T(), default=(216, 242, 254))
    color_dry = Tuple.T(3, Int.T(), default=(172, 208, 165))
    topo_resolution_min = Float.T(
        default=40.,
        help='minimum resolution of topography [dpi]')
    topo_resolution_max = Float.T(
        default=200.,
        help='maximum resolution of topography [dpi]')
    replace_topo_color_only = FloatTile.T(
        optional=True,
        help='replace topo color while keeping topographic shading')
    # Names of the color tables looked up through ``topo.cpt``.
    topo_cpt_wet = String.T(default='light_sea')
    topo_cpt_dry = String.T(default='light_land')
    # Frame annotation layout; chosen from the sign of ``lat`` when unset.
    axes_layout = String.T(optional=True)
    # Additional cities merged into the result of ``cities_in_region``.
    custom_cities = List.T(City.T())
    # Extra GMT settings; keys are upper-cased and override the defaults.
    gmt_config = Dict.T(String.T(), String.T())
    # Optional text drawn together with the axes.
    comment = String.T(optional=True)

    def __init__(self, gmtversion='newest', **kwargs):
        '''
        Initialize the declared attributes and reset all lazily computed
        state.

        :param gmtversion: GMT version identifier passed on to ``gmtpy``
        :param kwargs: values for the declared attributes, forwarded to
            ``Object.__init__``
        '''
        Object.__init__(self, **kwargs)
        self._gmt = None
        self._scaler = None
        self._widget = None
        # ``_layout`` and ``_pxyr`` are tested against ``None`` by the
        # ``layout`` and ``pxyr`` properties, so they must exist before the
        # first setup.
        self._layout = None
        self._corners = None
        self._wesn = None
        self._minarea = None
        self._coastline_resolution = None
        self._rivers = None
        self._dems = None
        self._have_topo_land = None
        self._have_topo_ocean = None
        self._jxyr = None
        self._pxyr = None
        self._prep_topo_have = None
        # Drawing state flags; reset again when the GMT instance is set up.
        self._have_drawn_axes = False
        self._have_drawn_labels = False
        self._labels = []
        self._area_labels = []
        self._gmtversion = gmtversion

    def save(self, outpath, resolution=75., oversample=2., size=None,
             width=None, height=None):

        '''
        Save the image.

        Save the image to ``outpath``. The format is determined by the filename
        extension. Formats are handled as follows: ``'.eps'`` and ``'.ps'``
        produce EPS and PS, respectively, directly with GMT. If the file name
        ends with ``'.pdf'``, GMT output is fed through ``gmtpy-epstopdf`` to
        create a PDF file. For any other filename extension, output is first
        converted to PDF with ``gmtpy-epstopdf``, then with ``pdftocairo`` to
        PNG with a resolution oversampled by the factor ``oversample`` and
        finally the PNG is downsampled and converted to the target format with
        ``convert``. The resolution of rasterized target image can be
        controlled either by ``resolution`` in DPI or by specifying ``width``
        or ``height`` or ``size``, where the latter fits the image into a
        square with given side length.
        '''

        gmt = self.gmt
        self.draw_labels()
        self.draw_axes()
        if self.show_topo and self.show_topo_scale:
            self._draw_topo_scale()

        gmt.save(outpath, resolution=resolution, oversample=oversample,
                 size=size, width=width, height=height)

    @property
    def scaler(self):
        if self._scaler is None:
            self._setup_geometry()

        return self._scaler

    @property
    def wesn(self):
        if self._wesn is None:
            self._setup_geometry()

        return self._wesn

    @property
    def widget(self):
        if self._widget is None:
            self._setup()

        return self._widget

    @property
    def layout(self):
        if self._layout is None:
            self._setup()

        return self._layout

    @property
    def jxyr(self):
        if self._jxyr is None:
            self._setup()

        return self._jxyr

    @property
    def pxyr(self):
        if self._pxyr is None:
            self._setup()

        return self._pxyr

    @property
    def gmt(self):
        if self._gmt is None:
            self._setup()

        if self._have_topo_ocean is None:
            self._draw_background()

        return self._gmt

    def _setup(self):
        if not self._widget:
            self._setup_geometry()

        self._setup_lod()
        self._setup_gmt()

    def _setup_geometry(self):
        '''
        Derive the region size, its corner coordinates, the bounding region
        and the axis scaler from page size, margins and ``radius``.

        Sets ``_wreg``, ``_hreg``, ``_corners``, ``_wesn`` and ``_scaler``.
        '''
        margin_l, margin_r, margin_t, margin_b = self._expand_margins()
        wpage = self.width - (margin_l + margin_r)
        hpage = self.height - (margin_t + margin_b)

        # The shorter page dimension spans 2*radius, the longer one is
        # stretched by the aspect ratio of the drawable page area.
        diameter = self.radius * 2.0
        if wpage >= hpage:
            wreg, hreg = diameter * (wpage/hpage), diameter
        else:
            wreg, hreg = diameter, diameter * (hpage/wpage)

        self._wreg = wreg
        self._hreg = hreg

        self._corners = corners(self.lon, self.lat, wreg, hreg)
        west, east, south, north = extent(self.lon, self.lat, wreg, hreg, 10)

        data_ranges = ((west, east), (south, north), (-6000., 4500.))

        axes = (
            gmtpy.Ax(mode='min-max', approx_ticks=4.),
            gmtpy.Ax(mode='min-max', approx_ticks=4.),
            gmtpy.Ax(mode='min-max', inc=1000., label='Height',
                     scaled_unit='km', scaled_unit_factor=0.001))

        scaler = gmtpy.ScaleGuru(data_tuples=[data_ranges], axes=axes)

        # The region finally used is the one chosen by the scaler.
        par = scaler.get_params()
        self._wesn = par['xmin'], par['xmax'], par['ymin'], par['ymax']
        self._scaler = scaler

    def _setup_lod(self):
        '''
        Choose the level of detail: coastline resolution, whether rivers are
        available, the minimum feature area and the topography datasets for
        ocean and land.
        '''
        # The values are not used; the unpacking is kept so that a missing
        # region still fails right here.
        w, e, s, n = self._wesn

        is_large_region = self.radius > 1500.*km
        self._minarea = (self.skip_feature_factor * self.radius/km)**2

        self._coastline_resolution = 'i' if is_large_region else 'f'
        self._rivers = not is_large_region

        self._prep_topo_have = {}
        self._dems = {}

        # Grid spacing bounds in degrees, from the requested dpi range and
        # the page height in inches.
        height_inch = self.height * (gmtpy.cm/gmtpy.inch)
        span_deg = 2.0 * self.radius * m2d

        dmin = span_deg / (self.topo_resolution_max * height_inch)
        dmax = span_deg / (self.topo_resolution_min * height_inch)

        for kind in ['ocean', 'land']:
            names = topo.select_dem_names(kind, dmin, dmax, self._wesn)
            self._dems[kind] = names
            if names:
                logger.debug('using topography dataset %s for %s'
                             % (','.join(names), kind))

    def _expand_margins(self):
        if len(self.margins) == 0 or len(self.margins) > 4:
            ml = mr = mt = mb = 2.0
        elif len(self.margins) == 1:
            ml = mr = mt = mb = self.margins[0]
        elif len(self.margins) == 2:
            ml = mr = self.margins[0]
            mt = mb = self.margins[1]
        elif len(self.margins) == 4:
            ml, mr, mt, mb = self.margins

        return ml, mr, mt, mb

    def _setup_gmt(self):
        '''
        Create the GMT instance and derive layout, widget and the argument
        lists used by the drawing methods.

        Sets ``_gmt``, ``_layout``, ``_widget``, ``_jxyr`` and ``_pxyr`` and
        resets the "already drawn" flags for axes and labels.
        '''
        w, h = self.width, self.height
        scaler = self._scaler

        # The configuration parameter names differ between GMT 5 and older
        # versions.
        if gmtpy.is_gmt5(self._gmtversion):
            gmtconf = dict(
                MAP_TICK_PEN_PRIMARY='1.25p',
                MAP_TICK_PEN_SECONDARY='1.25p',
                MAP_TICK_LENGTH_PRIMARY='0.2c',
                MAP_TICK_LENGTH_SECONDARY='0.6c',
                FONT_ANNOT_PRIMARY='12p,1,black',
                FONT_LABEL='12p,1,black',
                PS_CHAR_ENCODING='ISOLatin1+',
                MAP_FRAME_TYPE='fancy',
                FORMAT_GEO_MAP='D',
                PS_MEDIA='Custom_%ix%i' % (
                    w*gmtpy.cm,
                    h*gmtpy.cm),
                PS_PAGE_ORIENTATION='portrait',
                MAP_GRID_PEN_PRIMARY='thinnest,0/50/0',
                MAP_ANNOT_OBLIQUE='6')
        else:
            gmtconf = dict(
                TICK_PEN='1.25p',
                TICK_LENGTH='0.2c',
                ANNOT_FONT_PRIMARY='1',
                ANNOT_FONT_SIZE_PRIMARY='12p',
                LABEL_FONT='1',
                LABEL_FONT_SIZE='12p',
                CHAR_ENCODING='ISOLatin1+',
                BASEMAP_TYPE='fancy',
                PLOT_DEGREE_FORMAT='D',
                PAPER_MEDIA='Custom_%ix%i' % (
                    w*gmtpy.cm,
                    h*gmtpy.cm),
                GRID_PEN_PRIMARY='thinnest/0/50/0',
                DOTS_PR_INCH='1200',
                OBLIQUE_ANNOTATION='6')

        # User supplied settings override the defaults; keys are upper-cased.
        gmtconf.update(
            (k.upper(), v) for (k, v) in self.gmt_config.items())

        gmt = gmtpy.GMT(config=gmtconf, version=self._gmtversion)

        layout = gmt.default_layout()

        layout.set_fixed_margins(*[x*cm for x in self._expand_margins()])

        widget = layout.get_widget()
        # Keep the widget's original 'J' entry under 'P' (used via ``PXY``
        # for page coordinates) and replace 'J' by a ``-JA`` projection
        # centered on the map center.
        widget['P'] = widget['J']
        widget['J'] = ('-JA%g/%g' % (self.lon, self.lat)) + '/%(width)gp'
        # Region given by the corner coordinates (trailing 'r').
        scaler['R'] = '-R%g/%g/%g/%gr' % self._corners

        # aspect = gmtpy.aspect_for_projection(
        #     gmt.installation['version'], *(widget.J() + scaler.R()))

        # The aspect ratio is obtained by projecting the region corners; it
        # has to be set before the widget's J/X/Y arguments are queried
        # below.
        aspect = self._map_aspect(jr=widget.J() + scaler.R())
        widget.set_aspect(aspect)

        self._gmt = gmt
        self._layout = layout
        self._widget = widget
        self._jxyr = self._widget.JXY() + self._scaler.R()
        self._pxyr = self._widget.PXY() + [
            '-R%g/%g/%g/%g' % (0, widget.width(), 0, widget.height())]
        self._have_drawn_axes = False
        self._have_drawn_labels = False

    def _draw_background(self):
        self._have_topo_land = False
        self._have_topo_ocean = False
        if self.show_topo:
            self._have_topo = self._draw_topo()

        self._draw_basefeatures()

    def _get_topo_tile(self, k):
        '''
        Return ``(tile, demname)`` for the first dataset of kind ``k``
        (``'ocean'`` or ``'land'``) which provides a tile for the region.

        :raises NoTopo: if no dataset yields a usable tile
        '''
        tile = None
        demname = None
        for demname in self._dems[k]:
            tile = topo.get(demname, self._wesn)
            if tile is not None:
                break

        if not tile:
            raise NoTopo()

        return tile, demname

    def _prep_topo(self, k):
        '''
        Prepare the grid file and illumination arguments for topography of
        kind ``k`` (``'ocean'`` or ``'land'``).

        The result ``(grdfile, ilumargs)`` is cached per dataset name in
        ``self._prep_topo_have``.

        :raises NoTopo: propagated from ``_get_topo_tile``
        '''
        gmt = self._gmt
        t, demname = self._get_topo_tile(k)

        if demname not in self._prep_topo_have:

            grdfile = gmt.tempfilename()

            # True if every row of the data equals the first row; shading
            # is skipped in that case.
            is_flat = num.all(t.data[0] == t.data)

            gmtpy.savegrd(
                t.x(), t.y(), t.data, filename=grdfile, naming='lonlat')

            if self.illuminate and not is_flat:
                if k == 'ocean':
                    factor = self.illuminate_factor_ocean
                else:
                    factor = self.illuminate_factor_land

                # Gradient grid used as intensity (-I) for grdimage.
                ilumfn = gmt.tempfilename()
                gmt.grdgradient(
                    grdfile,
                    N='e%g' % factor,
                    A=-45,
                    G=ilumfn,
                    out_discard=True)

                ilumargs = ['-I%s' % ilumfn]
            else:
                ilumargs = []

            if self.replace_topo_color_only:
                # The shading above was computed from the topography; from
                # here on the grid to be colored is derived from the
                # replacement tile.
                t2 = self.replace_topo_color_only
                grdfile2 = gmt.tempfilename()

                gmtpy.savegrd(
                    t2.x(), t2.y(), t2.data, filename=grdfile2,
                    naming='lonlat')

                # Resample the replacement onto the topography grid; the
                # option selecting the interpolation is named differently
                # in GMT 5 ('n') and older versions ('Q').
                if gmt.is_gmt5():
                    gmt.grdsample(
                        grdfile2,
                        G=grdfile,
                        n='l',
                        I='%g/%g' % (t.dx, t.dy),  # noqa
                        R=grdfile,
                        out_discard=True)
                else:
                    gmt.grdsample(
                        grdfile2,
                        G=grdfile,
                        Q='l',
                        I='%g/%g' % (t.dx, t.dy),  # noqa
                        R=grdfile,
                        out_discard=True)

                # NOTE(review): presumably replaces undefined nodes by 0.0
                # -- confirm against the grdmath AND operator.
                gmt.grdmath(
                    grdfile, '0.0', 'AND', '=', grdfile2,
                    out_discard=True)

                grdfile = grdfile2

            self._prep_topo_have[demname] = grdfile, ilumargs

        return self._prep_topo_have[demname]

    def _draw_topo(self):
        '''
        Draw ocean and land topography, each clipped with ``pscoast``.

        Sets ``_have_topo_ocean`` and ``_have_topo_land`` depending on
        whether a tile was available. Returns nothing.
        '''
        widget = self._widget
        scaler = self._scaler
        gmt = self._gmt
        cres = self._coastline_resolution
        minarea = self._minarea

        JXY = widget.JXY()
        R = scaler.R()

        try:
            grdfile, ilumargs = self._prep_topo('ocean')
            # Start clipping (S='c'), draw the image, then end clipping
            # (Q=True).
            gmt.pscoast(D=cres, S='c', A=minarea, *(JXY+R))
            gmt.grdimage(grdfile, C=topo.cpt(self.topo_cpt_wet),
                         *(ilumargs+JXY+R))
            gmt.pscoast(Q=True, *(JXY+R))
            self._have_topo_ocean = True
        except NoTopo:
            self._have_topo_ocean = False

        try:
            grdfile, ilumargs = self._prep_topo('land')
            # Same sequence as above, with the clip path given by G='c'.
            gmt.pscoast(D=cres, G='c', A=minarea, *(JXY+R))
            gmt.grdimage(grdfile, C=topo.cpt(self.topo_cpt_dry),
                         *(ilumargs+JXY+R))
            gmt.pscoast(Q=True, *(JXY+R))
            self._have_topo_land = True
        except NoTopo:
            self._have_topo_land = False

    def _draw_topo_scale(self, label='Elevation [km]'):
        '''
        Draw a horizontal color scale for the combined wet/dry topography
        color table below the map.

        :param label: annotation of the scale
        '''
        cpt_dry = read_cpt(topo.cpt(self.topo_cpt_dry))
        cpt_wet = read_cpt(topo.cpt(self.topo_cpt_wet))
        merged = cpt_merge_wet_dry(cpt_wet, cpt_dry)

        # Rescale the level bounds by ``km`` to match the default label.
        for level in merged.levels:
            level.vmin /= km
            level.vmax /= km

        cpt_path = self.gmt.tempfilename() + '.cpt'
        write_cpt(merged, cpt_path)

        (w, h), (xo, yo) = self.widget.get_size()

        # Centered horizontally, 2 cm below the widget origin, as wide as
        # the widget and 0.5 cm high.
        position = '%gp/%gp/%gp/%gph' % (
            xo + 0.5*w, yo - 2.0*gmtpy.cm, w, 0.5*gmtpy.cm)

        self.gmt.psscale(
            D=position,
            C=cpt_path,
            B='1:%s:' % label)

    def _draw_basefeatures(self):
        '''
        Draw coastlines, optionally rivers, and flat fills for land and
        ocean wherever no topography was drawn; finally the plates if
        ``show_plates`` is set.
        '''
        gmt = self._gmt

        if self.show_rivers and self._rivers:
            river_args = ['-Ir/0.25p,%s' % gmtpy.color(self.color_wet)]
        else:
            river_args = []

        # Only fill what has not been covered by topography.
        fill = {}
        if not self._have_topo_land:
            fill['G'] = self.color_dry

        if not self._have_topo_ocean:
            fill['S'] = self.color_wet

        outline_color = gmtpy.color(darken(gmtpy.color_tup(self.color_dry)))

        gmt.pscoast(
            D=self._coastline_resolution,
            W='thinnest,%s' % outline_color,
            A=self._minarea,
            *(river_args + self._jxyr), **fill)

        if self.show_plates:
            self.draw_plates()

    def _draw_axes(self):
        '''
        Draw the map frame with annotations and a length scale, optionally
        the center mark, grid lines and the comment text.
        '''
        gmt = self._gmt
        scaler = self._scaler
        widget = self._widget

        # Default frame layout depends on the sign of the center latitude.
        if self.axes_layout is None:
            if self.lat > 0.0:
                axes_layout = 'WSen'
            else:
                axes_layout = 'WseN'
        else:
            axes_layout = self.axes_layout

        # Length of the scale bar: a "nice" value near radius/5, divided by
        # 1000 for the 'k' unit suffix used below.
        scale_km = gmtpy.nice_value(self.radius/5.) / 1000.

        if self.show_center_mark:
            gmt.psxy(
                in_rows=[[self.lon, self.lat]],
                S='c20p', W='2p,black',
                *self._jxyr)

        # The 'g' entries in the template add grid lines at the tick
        # increment.
        if self.show_grid:
            btmpl = ('%(xinc)gg%(xinc)g:%(xlabel)s:/'
                     '%(yinc)gg%(yinc)g:%(ylabel)s:')
        else:
            btmpl = '%(xinc)g:%(xlabel)s:/%(yinc)g:%(ylabel)s:'

        gmt.psbasemap(
            B=(btmpl % scaler.get_params())+axes_layout,
            L=('x%gp/%gp/%g/%g/%gk' % (
                6./7*widget.width(),
                widget.height()/7.,
                self.lon,
                self.lat,
                scale_km)),
            *self._jxyr)

        if self.comment:
            font_size = self.gmt.label_font_size()

            # NOTE(review): ``east`` and ``south`` are not used below.
            _, east, south, _ = self._wesn
            # The comment is placed at (1, 0) of a unit region in page
            # coordinates with anchor 'BR'; the row format differs between
            # GMT 5 and older versions.
            if gmt.is_gmt5():
                row = [
                    1, 0,
                    '%gp,%s,%s' % (font_size, 0, 'black'), 'BR',
                    self.comment]

                farg = ['-F+f+j']
            else:
                row = [1, 0, font_size, 0, 0, 'BR', self.comment]
                farg = []

            gmt.pstext(
                in_rows=[row],
                N=True,
                R=(0, 1, 0, 1),
                D='%gp/%gp' % (-font_size*0.2, font_size*0.3),
                *(widget.PXY() + farg))

    def draw_axes(self):
        if not self._have_drawn_axes:
            self._draw_axes()
            self._have_drawn_axes = True

    def _have_coastlines(self):
        '''
        Check whether any coastline point lies inside the map region.

        The coastlines are dumped by ``pscoast`` (option ``M``) to a
        temporary file whose coordinate lines are parsed as ``lon lat``
        pairs.

        :returns: true if at least one dumped point is inside ``_wesn``,
            ``False`` if the dump contains no points at all
        '''
        gmt = self._gmt

        checkfile = gmt.tempfilename()

        gmt.pscoast(
            M=True,
            D=self._coastline_resolution,
            W='thinnest,black',
            A=self._minarea,
            out_filename=checkfile,
            *self._jxyr)

        points = []
        with open(checkfile, 'r') as f:
            for line in f:
                ls = line.strip()
                # Skip comments, segment headers and blank lines.
                if not ls or ls.startswith(('#', '>')):
                    continue

                plon, plat = [float(x) for x in ls.split()]
                points.append((plat, plon))

        if not points:
            # An empty list would turn into a 1-D array instead of the
            # (n, 2) array handed to ``points_in_region`` otherwise.
            return False

        # The builtin ``float`` replaces the alias ``num.float``, which was
        # removed from NumPy.
        points = num.array(points, dtype=float)
        return num.any(points_in_region(points, self._wesn))

    def have_coastlines(self):
        self.gmt
        return self._have_coastlines()

    def project(self, lats, lons, jr=None):
        onepoint = False
        if isinstance(lats, float) and isinstance(lons, float):
            lats = [lats]
            lons = [lons]
            onepoint = True

        if jr is not None:
            j, r = jr
            gmt = gmtpy.GMT(version=self._gmtversion)
        else:
            j, _, _, r = self.jxyr
            gmt = self.gmt

        f = BytesIO()
        gmt.mapproject(j, r, in_columns=(lons, lats), out_stream=f, D='p')
        f.seek(0)
        data = num.loadtxt(f, ndmin=2)
        xs, ys = data.T
        if onepoint:
            xs = xs[0]
            ys = ys[0]
        return xs, ys

    def _map_box(self, jr=None):
        ll_lon, ll_lat, ur_lon, ur_lat = self._corners

        xs_corner, ys_corner = self.project(
            (ll_lat, ur_lat), (ll_lon, ur_lon), jr=jr)

        w = xs_corner[1] - xs_corner[0]
        h = ys_corner[1] - ys_corner[0]

        return w, h

    def _map_aspect(self, jr=None):
        w, h = self._map_box(jr=jr)
        return h/w

    def _draw_labels(self):
        '''
        Place and draw all queued point labels and area labels.

        Point labels can be anchored at four positions around their point.
        Anchors which leave the map box or whose text rectangle would cover
        another labelled point are discarded. The remaining choices are
        improved greedily to minimize mutual overlaps; labels which still
        overlap afterwards are dropped.

        Area labels come with a list of candidate locations; the first
        candidate which lies inside the map box and collides neither with an
        already placed label nor with a labelled point is used.
        '''
        # Bound up front because the helper below uses it, also when there
        # are no point labels.
        la = num.logical_and

        points_taken = []
        regions_taken = []

        def no_points_in_rect(xs, ys, xmin, ymin, xmax, ymax):
            # True if none of the points lies strictly inside the rectangle.
            xx = not num.any(la(la(xmin < xs, xs < xmax),
                                la(ymin < ys, ys < ymax)))
            return xx

        def roverlaps(a, b):
            # Rectangles are (xmin, ymin, xmax, ymax).
            return (a[0] < b[2] and b[0] < a[2] and
                    a[1] < b[3] and b[1] < a[3])

        w, h = self._map_box()

        label_font_size = self.gmt.label_font_size()

        if self._labels:

            n = len(self._labels)

            lons, lats, texts, sx, sy, colors, fonts, font_sizes, styles = \
                list(zip(*self._labels))

            font_sizes = [
                (font_size or label_font_size) for font_size in font_sizes]

            # Builtin ``float`` instead of the removed alias ``num.float``.
            sx = num.array(sx, dtype=float)
            sy = num.array(sy, dtype=float)

            xs, ys = self.project(lats, lons)

            points_taken.append((xs, ys))

            # Extent of each label's text box.
            dxs = num.zeros(n)
            dys = num.zeros(n)

            for i in range(n):
                dx, dy = gmtpy.text_box(
                    texts[i],
                    font=fonts[i],
                    font_size=font_sizes[i],
                    **styles[i])

                dxs[i] = dx
                dys[i] = dy

            # Per anchor: does the label stay inside the map box? The order
            # corresponds to ``anchor_strs`` below.
            anchors_ok = (
                la(xs + sx + dxs < w, ys + sy + dys < h),
                la(xs - sx - dxs > 0., ys - sy - dys > 0.),
                la(xs + sx + dxs < w, ys - sy - dys > 0.),
                la(xs - sx - dxs > 0., ys + sy + dys < h),
            )

            # Per anchor: rectangle occupied by the label.
            arects = [
                (xs, ys, xs + sx + dxs, ys + sy + dys),
                (xs - sx - dxs, ys - sy - dys, xs, ys),
                (xs, ys - sy - dys, xs + sx + dxs, ys),
                (xs - sx - dxs, ys, xs, ys + sy + dys)]

            # Additionally, a label must not cover any labelled point.
            for i in range(n):
                for ianch in range(4):
                    anchors_ok[ianch][i] &= no_points_in_rect(
                        xs, ys, *[xxx[i] for xxx in arects[ianch]])

            anchor_choices = []
            anchor_take = []
            for i in range(n):
                choices = [ianch for ianch in range(4)
                           if anchors_ok[ianch][i]]
                anchor_choices.append(choices)
                if choices:
                    anchor_take.append(choices[0])
                else:
                    anchor_take.append(None)

            def cost(anchor_take):
                # Number of ordered pairs of placed labels which overlap.
                noverlaps = 0
                for i in range(n):
                    for j in range(n):
                        if i != j:
                            i_take = anchor_take[i]
                            j_take = anchor_take[j]
                            if i_take is None or j_take is None:
                                continue
                            r_i = [xxx[i] for xxx in arects[i_take]]
                            r_j = [xxx[j] for xxx in arects[j_take]]
                            if roverlaps(r_i, r_j):
                                noverlaps += 1

                return noverlaps

            # Greedy improvement: switch single anchors while that reduces
            # the number of overlaps, for at most 30 sweeps.
            cur_cost = cost(anchor_take)
            imax = 30
            while cur_cost != 0 and imax > 0:
                for i in range(n):
                    for t in anchor_choices[i]:
                        anchor_take_new = list(anchor_take)
                        anchor_take_new[i] = t
                        new_cost = cost(anchor_take_new)
                        if new_cost < cur_cost:
                            anchor_take = anchor_take_new
                            cur_cost = new_cost

                imax -= 1

            # Drop labels, one at a time, until no overlaps remain.
            while cur_cost != 0:
                for i in range(n):
                    anchor_take_new = list(anchor_take)
                    anchor_take_new[i] = None
                    new_cost = cost(anchor_take_new)
                    if new_cost < cur_cost:
                        anchor_take = anchor_take_new
                        cur_cost = new_cost
                        break

            anchor_strs = ['BL', 'TR', 'TL', 'BR']

            for i in range(n):
                ianchor = anchor_take[i]
                color = colors[i]
                if color is None:
                    color = 'black'

                if ianchor is not None:
                    regions_taken.append([xxx[i] for xxx in arects[ianchor]])

                    anchor = anchor_strs[ianchor]

                    # Shift the text away from the point, towards the side
                    # opposite to the anchor.
                    yoff = [-sy[i], sy[i]][anchor[0] == 'B']
                    xoff = [-sx[i], sx[i]][anchor[1] == 'L']
                    if self.gmt.is_gmt5():
                        row = (
                            lons[i], lats[i],
                            '%i,%s,%s' % (font_sizes[i], fonts[i], color),
                            anchor,
                            texts[i])

                        farg = ['-F+f+j']
                    else:
                        row = (
                            lons[i], lats[i],
                            font_sizes[i], 0, fonts[i], anchor,
                            texts[i])
                        farg = ['-G%s' % color]

                    self.gmt.pstext(
                        in_rows=[row],
                        D='%gp/%gp' % (xoff, yoff),
                        *(self.jxyr + farg),
                        **styles[i])

        if self._area_labels:

            for lons, lats, text, color, font, font_size, style in \
                    self._area_labels:

                if font_size is None:
                    font_size = label_font_size

                if color is None:
                    color = 'black'

                if self.gmt.is_gmt5():
                    farg = ['-F+f+j']
                else:
                    farg = ['-G%s' % color]

                xs, ys = self.project(lats, lons)
                dx, dy = gmtpy.text_box(
                    text, font=font, font_size=font_size, **style)

                # Text rectangles centered on each candidate location.
                rects = [xs-0.5*dx, ys-0.5*dy, xs+0.5*dx, ys+0.5*dy]

                # Builtin ``bool`` instead of the removed alias ``num.bool``.
                locs_ok = num.ones(xs.size, dtype=bool)

                for iloc in range(xs.size):
                    rcandi = [xxx[iloc] for xxx in rects]

                    locs_ok[iloc] = True
                    locs_ok[iloc] &= (
                        0 < rcandi[0] and rcandi[2] < w
                        and 0 < rcandi[1] and rcandi[3] < h)

                    overlap = False
                    for r in regions_taken:
                        if roverlaps(r, rcandi):
                            overlap = True
                            break

                    locs_ok[iloc] &= not overlap

                    for xs_taken, ys_taken in points_taken:
                        locs_ok[iloc] &= no_points_in_rect(
                            xs_taken, ys_taken, *rcandi)

                        if not locs_ok[iloc]:
                            break

                # Use the first acceptable candidate only.
                rows = []
                for iloc, (lon, lat) in enumerate(zip(lons, lats)):
                    if not locs_ok[iloc]:
                        continue

                    if self.gmt.is_gmt5():
                        row = (
                            lon, lat,
                            '%i,%s,%s' % (font_size, font, color),
                            'MC',
                            text)

                    else:
                        row = (
                            lon, lat,
                            font_size, 0, font, 'MC',
                            text)

                    rows.append(row)

                    regions_taken.append([xxx[iloc] for xxx in rects])
                    break

                self.gmt.pstext(
                    in_rows=rows,
                    *(self.jxyr + farg),
                    **style)

    def draw_labels(self):
        self.gmt
        if not self._have_drawn_labels:
            self._draw_labels()
            self._have_drawn_labels = True

    def add_label(
            self, lat, lon, text,
            offset_x=5., offset_y=5.,
            color=None,
            font='1',
            font_size=None,
            style={}):

        if 'G' in style:
            style = style.copy()
            color = style.pop('G')

        self._labels.append(
            (lon, lat, text, offset_x, offset_y, color, font, font_size,
             style))

    def add_area_label(
            self, lat, lon, text,
            color=None,
            font='3',
            font_size=None,
            style={}):

        self._area_labels.append(
            (lon, lat, text, color, font, font_size, style))

    def cities_in_region(self):
        '''
        Return the geonames cities inside the map region plus the custom
        cities, sorted by ascending population.
        '''
        from pyrocko.dataset import geonames

        found = geonames.get_cities_region(region=self.wesn, minpop=0)
        found.extend(self.custom_cities)
        found.sort(key=lambda city: city.population)
        return found

    def draw_cities(self,
                    exact=None,
                    include=[],
                    exclude=[],
                    nmax_soft=10,
                    psxy_style=dict(S='s5p', G='black')):

        cities = self.cities_in_region()

        if exact is not None:
            cities = [c for c in cities if c.name in exact]
            minpop = None
        else:
            cities = [c for c in cities if c.name not in exclude]
            minpop = 10**3
            for minpop_new in [1e3, 3e3, 1e4, 3e4, 1e5, 3e5, 1e6, 3e6, 1e7]:
                cities_new = [
                    c for c in cities
                    if c.population > minpop_new or c.name in include]

                if len(cities_new) == 0 or (
                        len(cities_new) < 3 and len(cities) < nmax_soft*2):
                    break

                cities = cities_new
                minpop = minpop_new
                if len(cities) <= nmax_soft:
                    break

        if cities:
            lats = [c.lat for c in cities]
            lons = [c.lon for c in cities]

            self.gmt.psxy(
                in_columns=(lons, lats),
                *self.jxyr, **psxy_style)

            for c in cities:
                try:
                    text = c.name.encode('iso-8859-1').decode('iso-8859-1')
                except UnicodeEncodeError:
                    text = c.asciiname

                self.add_label(c.lat, c.lon, text)

        self._cities_minpop = minpop

    def add_stations(self, stations, psxy_style=dict(S='t8p', G='black')):
        lats = [s.lat for s in stations]
        lons = [s.lon for s in stations]

        self.gmt.psxy(
            in_columns=(lons, lats),
            *self.jxyr, **psxy_style)

        for station in stations:
            self.add_label(station.lat, station.lon, '.'.join(
                x for x in (station.network, station.station) if x))

    def add_kite_scene(self, scene):
        '''
        Build a ``FloatTile`` from a scene's frame and displacement and
        return it. Nothing is drawn here.

        :param scene: object providing ``frame`` (with ``llLon``, ``llLat``,
            ``dLon``, ``dLat``) and ``displacement``
        '''
        frame = scene.frame
        return FloatTile(
            frame.llLon,
            frame.llLat,
            frame.dLon,
            frame.dLat,
            scene.displacement)

    def add_gnss_campaign(self, campaign,
                          psxy_style=dict(), labels=True):

        offsets = num.array([math.sqrt(s.east.shift**2 + s.north.shift**2)
                             for s in campaign.stations])
        scale = 1./offsets.max()

        default_psxy_style = {
            'h': 2,
            'h': 2,
            'W': '0.5p,black',
            't': '30',
            'G': 'black',
            'L': True,
            'S': 'e%d/0.95/8' % scale,
        }
        default_psxy_style.update(psxy_style)

        if labels:
            rows = [(s.lon, s.lat,
                     s.east.shift, s.north.shift,
                     s.east.sigma, s.north.sigma, 0)
                    for s in campaign.stations]
        else:
            rows = [(s.lon, s.lat,
                     s.east.shift, s.north.shift,
                     s.east.sigma, s.north.sigma, 0,
                     s.code)
                    for s in campaign.stations]

        kwargs = {}
        kwargs.update(default_psxy_style)
        kwargs.update(self.jxyr)

        self.gmt.psvelo(in_rows=rows, **kwargs)

    def draw_plates(self):
        '''
        Draw tectonic plate boundaries, plate motion vectors and plate name
        labels.

        Plates are only drawn if more than one plate is present in the map
        region. Velocities are shown relative to one plate chosen as fixed;
        the name of that plate is marked in its label (wrapped in ``@_``).
        '''
        from pyrocko.dataset import tectonics

        # Regular grid of sample points covering the region, used to find
        # out which plates are visible and where to put their labels.
        neast = 20
        nnorth = max(1, int(round(self._hreg/self._wreg * neast)))
        norths = num.linspace(-self._hreg*0.5, self._hreg*0.5, nnorth)
        easts = num.linspace(-self._wreg*0.5, self._wreg*0.5, neast)
        norths2 = num.repeat(norths, neast)
        easts2 = num.tile(easts, nnorth)
        lats, lons = od.ne_to_latlon(
            self.lat, self.lon, norths2, easts2)

        bird = tectonics.PeterBird2003()
        plates = bird.get_plates()

        color_plates = gmtpy.color('aluminium5')
        color_velocities = gmtpy.color('skyblue1')
        color_velocities_lab = gmtpy.color(darken(gmtpy.color_tup('skyblue1')))

        points = num.vstack((lats, lons)).T
        used = []
        for plate in plates:
            mask = plate.contains_points(points)
            if num.any(mask):
                used.append((plate, mask))

        if len(used) > 1:

            # Plate name -> score; the plate with the highest score is
            # tried first as fixed plate.
            candi_fixed = {}

            label_data = []
            for plate, mask in used:

                # Candidate label positions, ordered by distance from the
                # centroid of the plate's grid points.
                mean_north = num.mean(norths2[mask])
                mean_east = num.mean(easts2[mask])
                iorder = num.argsort(num.sqrt(
                        (norths2[mask] - mean_north)**2 +
                        (easts2[mask] - mean_east)**2))

                lat_candis = lats[mask][iorder]
                lon_candis = lons[mask][iorder]

                candi_fixed[plate.name] = lat_candis.size

                label_data.append((
                    lat_candis, lon_candis, plate, color_plates))

            boundaries = bird.get_boundaries()

            size = 2

            psxy_kwargs = []

            for boundary in boundaries:
                if num.any(points_in_region(boundary.points, self._wesn)):
                    for typ, part in boundary.split_types(
                            [['SUB'],
                             ['OSR', 'OTF', 'OCB', 'CTF', 'CCB', 'CRB']]):

                        # Separate names, so the grid coordinates above are
                        # not shadowed.
                        part_lats, part_lons = part.T

                        kwargs = {}
                        if typ[0] == 'SUB':
                            # Front symbols; their side depends on the
                            # boundary kind.
                            if boundary.kind == '\\':
                                kwargs['S'] = 'f%g/%gp+t+r' % (
                                    0.45*size, 3.*size)
                            elif boundary.kind == '/':
                                kwargs['S'] = 'f%g/%gp+t+l' % (
                                    0.45*size, 3.*size)

                            kwargs['G'] = color_plates

                        kwargs['in_columns'] = (part_lons, part_lats)
                        kwargs['W'] = '%gp,%s' % (size, color_plates)

                        psxy_kwargs.append(kwargs)

                        # Boost the score of the plate named on the
                        # respective side of the boundary so it is
                        # preferred as fixed plate.
                        if boundary.kind == '\\':
                            if boundary.name2 in candi_fixed:
                                candi_fixed[boundary.name2] += neast*nnorth

                        elif boundary.kind == '/':
                            if boundary.name1 in candi_fixed:
                                candi_fixed[boundary.name1] += neast*nnorth

            candi_fixed = sorted(
                candi_fixed, key=lambda name: -candi_fixed[name])

            candi_fixed.append(None)

            gsrm = tectonics.GSRM1()

            # Stays ``None`` if no candidate is known to the velocity model.
            fixed_plate_name = None

            for name in candi_fixed:
                if name not in gsrm.plate_names() \
                        and name not in gsrm.plate_alt_names():

                    continue

                vlats, vlons, vnorth, veast, vnorth_err, veast_err, corr = \
                    gsrm.get_velocities(name, region=self._wesn)

                fixed_plate_name = name

                self.gmt.psvelo(
                    in_columns=(
                        vlons, vlats, veast, vnorth, veast_err, vnorth_err,
                        corr),
                    W='0.25p,%s' % color_velocities,
                    A='9p+e+g%s' % color_velocities,
                    S='e0.2p/0.95/10',
                    *self.jxyr)

                # Label a random subset of the vectors with their magnitude;
                # skipped if there are no vectors in the region.
                if len(vlons) > 0:
                    for _ in range(len(vlons) // 50 + 1):
                        ii = random.randint(0, len(vlons)-1)
                        v = math.sqrt(vnorth[ii]**2 + veast[ii]**2)
                        self.add_label(
                            vlats[ii], vlons[ii], '%.0f' % v,
                            font_size=0.7*self.gmt.label_font_size(),
                            style=dict(
                                G=color_velocities_lab))

                break

            for (lat_candis, lon_candis, plate, color) in label_data:
                full_name = bird.full_name(plate.name)
                if plate.name == fixed_plate_name:
                    full_name = '@_' + full_name + '@_'

                self.add_area_label(
                    lat_candis, lon_candis,
                    full_name,
                    color=color,
                    font='3')

            for kwargs in psxy_kwargs:
                self.gmt.psxy(*self.jxyr, **kwargs)
示例#28
0
 class B(Object):
     # Container attributes holding ``A`` objects: a list, a 3-tuple and a
     # dict with integer keys.
     a_list = List.T(A.T())
     a_tuple = Tuple.T(3, A.T())
     a_dict = Dict.T(Int.T(), A.T())
     # Plain float attribute.
     b = Float.T()
示例#29
0
class BeamForming(Object):
    '''
    Shift-and-stack (beam forming) of traces relative to an array centre.

    :py:meth:`process` shifts every trace by ``distance * slowness``, where
    the distance is measured from the mean station position along the back
    azimuth, and sums the shifted traces per channel.
    '''
    station_c = Station.T(optional=True)
    bazi = Float.T()
    slow = Float.T()
    diff_dt_treat = String.T(help='how to handle differing sampling rates:'
                             ' oversample(default) or downsample')
    normalize_std = Bool.T()
    post_normalize = Bool.T()
    t_shifts = Dict.T(String.T(), Float.T())

    def __init__(self,
                 stations,
                 traces,
                 normalize=True,
                 post_normalize=False,
                 diff_dt_treat='oversample'):
        '''
        :param stations: list of station objects of the array
        :param traces: list of traces to be stacked
        :param normalize: if true, each shifted trace is divided by its
            standard deviation before stacking
        :param post_normalize: if true, each stack is divided by the number
            of traces counted for its channel
        :param diff_dt_treat: ``'oversample'`` or ``'downsample'``
        '''
        self.stations = stations
        self.c_lat_lon_z = self.center_lat_lon(stations)
        self.traces = traces
        self.diff_dt_treat = diff_dt_treat
        self.normalize_std = normalize
        self.post_normalize = post_normalize
        self.station_c = None

    def process(self,
                event,
                timing,
                bazi=None,
                slow=None,
                restitute=False,
                *args,
                **kwargs):
        '''
        Shift the traces towards the array centre and stack them by channel.

        :param event: if truthy, back azimuth and slowness are derived from
            the ray returned by ``timing`` in the CRUST2 profile at the event
            location; otherwise ``bazi`` and ``slow`` are used as given
        :param timing: CakeTiming. Uses the definition without the offset.
        :param bazi: back azimuth in degrees, used when ``event`` is falsy
        :param slow: slowness, used when ``event`` is falsy (presumably in
            s/m: it is logged as s/km after multiplication with ``cake.km``)
        :param restitute: not used by this method
        :param fn_dump_center: filename to where center stations shall be
            dumped (keyword, default ``'array_center.pf'``)
        :param fn_beam: filename of beam trace (keyword, default
            ``'beam.mseed'``)
        :param network: network code(optional keyword, default ``''``)
        :param station: station code(optional keyword, default ``'STK'``)

        :returns: tuple ``(shifted_traces, stack_trace, t_shifts)`` where
            ``stack_trace`` is the stack touched by the last processed trace,
            or ``None`` if no ray could be computed for ``event``
        :raises ValueError: if no trace was available for stacking
        '''
        logger.debug('start beam forming')
        stations = self.stations
        network_code = kwargs.get('network', '')
        station_code = kwargs.get('station', 'STK')
        c_station_id = (network_code, station_code)
        t_shifts = []
        lat_c, lon_c, z_c = self.c_lat_lon_z

        self.station_c = Station(lat=float(lat_c),
                                 lon=float(lon_c),
                                 elevation=float(z_c),
                                 depth=0.,
                                 name='Array Center',
                                 network=c_station_id[0],
                                 station=c_station_id[1][:5])
        fn_dump_center = kwargs.get('fn_dump_center', 'array_center.pf')
        fn_beam = kwargs.get('fn_beam', 'beam.mseed')
        if event:
            mod = cake.load_model(crust2_profile=(event.lat, event.lon))
            dist = ortho.distance_accurate50m(event, self.station_c)
            ray = timing.t(mod, (event.depth, dist), get_ray=True)

            if ray is None:
                logger.error(
                    'None of defined phases available at beam station:\n %s' %
                    self.station_c)
                return
            else:
                b = ortho.azimuth(self.station_c, event)
                # map negative azimuths into the range 0..360 degrees
                if b >= 0.:
                    self.bazi = b
                elif b < 0.:
                    self.bazi = 360. + b
                self.slow = ray.p / (cake.r2d * cake.d2m)
        else:
            self.bazi = bazi
            self.slow = slow

        logger.info(
            'stacking %s with slowness %1.4f s/km at back azimut %1.1f '
            'degrees' %
            ('.'.join(c_station_id), self.slow * cake.km, self.bazi))

        lat0 = num.array([lat_c] * len(stations))
        lon0 = num.array([lon_c] * len(stations))
        lats = num.array([s.lat for s in stations])
        lons = num.array([s.lon for s in stations])
        ns, es = ortho.latlon_to_ne_numpy(lat0, lon0, lats, lons)
        # builtin float: the ``num.float`` alias is gone from current NumPy
        theta = float(self.bazi * num.pi / 180.)
        R = num.array([[num.cos(theta), -num.sin(theta)],
                       [num.sin(theta), num.cos(theta)]])
        # second row of the rotated (east, north) offsets: station offset
        # along the back azimuth direction
        distances = R.dot(num.vstack((es, ns)))[1]
        channels = set()
        self.stacked = {}
        num_stacked = {}
        self.t_shifts = {}
        self.shifted_traces = []
        taperer = trace.CosFader(xfrac=0.05)
        if self.diff_dt_treat == 'downsample':
            self.traces.sort(key=lambda x: x.deltat)
        elif self.diff_dt_treat == 'oversample':
            if self.traces:
                # smallest sampling interval, determined before resampling
                dt_min = min(t.deltat for t in self.traces)
                for tr in self.traces:
                    tr.resample(dt_min)

        stack_trace = None
        for tr in self.traces:
            # never stack a previously produced beam trace onto itself
            if tr.nslc_id[:2] == c_station_id:
                continue
            tr = tr.copy(data=True)
            # remove the mean, working in double precision
            tr.ydata = tr.ydata.astype(
                num.float64) - tr.ydata.mean(dtype=num.float64)
            tr.taper(taperer)
            try:
                stack_trace = self.stacked[tr.channel]
                num_stacked[tr.channel] += 1
            except KeyError:
                # first trace of this channel: start from a zeroed copy
                stack_trace = tr.copy(data=True)
                stack_trace.set_ydata(num.zeros(len(stack_trace.get_ydata())))

                stack_trace.set_codes(network=c_station_id[0],
                                      station=c_station_id[1],
                                      location='',
                                      channel=tr.channel)

                self.stacked[tr.channel] = stack_trace
                channels.add(tr.channel)
                num_stacked[tr.channel] = 1

            nslc_id = tr.nslc_id

            try:
                stat = [
                    s for s in stations
                    if util.match_nslc('%s.%s.%s.*' % s.nsl(), nslc_id)][0]
            except IndexError:
                # NOTE(review): a trace without matching station ends the
                # whole loop (not just this trace) although it has already
                # been counted in num_stacked -- confirm that ``break``
                # rather than ``continue`` is intended.
                break

            i = stations.index(stat)
            d = distances[i]
            t_shift = d * self.slow
            t_shifts.append(t_shift)
            tr.shift(t_shift)
            self.t_shifts[tr.nslc_id[:2]] = t_shift
            if self.normalize_std:
                tr.ydata = tr.ydata / tr.ydata.std()

            if num.abs(tr.deltat - stack_trace.deltat) > 0.000001:
                if self.diff_dt_treat == 'downsample':
                    stack_trace.downsample_to(tr.deltat)
                # NOTE(review): the option is spelled 'oversample' everywhere
                # else, so this branch never matches -- confirm intent before
                # changing it, it would turn a tolerated mismatch into an
                # error.
                elif self.diff_dt_treat == 'upsample':
                    raise Exception(
                        'something went wrong with the upsampling, previously')
            stack_trace.add(tr)
            self.shifted_traces.append(tr)

        if stack_trace is None:
            # nothing was stacked; fail with a clear message before any
            # output file is written
            raise ValueError('no traces available for stacking')

        if self.post_normalize:
            for ch, tr in self.stacked.items():
                tr.set_ydata(tr.get_ydata() / num_stacked[ch])

        self.save_station(fn_dump_center)
        self.checked_nslc([stack_trace])
        self.save(stack_trace, fn_beam)
        return self.shifted_traces, stack_trace, t_shifts

    def checked_nslc(self, trs):
        '''
        Truncate the codes of the given traces to 2/5/2/3 characters
        (network/station/location/channel) and warn about every change.
        '''
        for tr in trs:
            oldids = tr.nslc_id
            net, sta, loc, cha = oldids
            tr.set_codes(network=net[:2],
                         station=sta[:5],
                         location=loc[:2],
                         channel=cha[:3])
            newids = tr.nslc_id
            if oldids != newids:
                logger.warning('nslc id truncated: %s to %s' %
                               ('.'.join(oldids), '.'.join(newids)))

    def snuffle(self):
        '''Scrutinize the shifted traces.'''
        from pyrocko import snuffler
        snuffler.snuffle(self.shifted_traces)

    def center_lat_lon(self, stations):
        '''
        Return ``(lat, lon, z)`` of the array centre.

        ``lat`` and ``lon`` are the arithmetic means of the station
        coordinates in degrees, ``z`` is the mean of elevation minus depth.
        '''

        lats = num.zeros(len(stations))
        lons = num.zeros(len(stations))
        elevations = num.zeros(len(stations))
        depths = num.zeros(len(stations))
        for i, s in enumerate(stations):
            lats[i] = s.lat * torad
            lons[i] = s.lon * torad
            depths[i] = s.depth
            elevations[i] = s.elevation

        z = num.mean(elevations - depths)
        return (lats.mean() * 180 / num.pi, lons.mean() * 180 / num.pi, z)

    def plot(self, fn='beam_shifts.png'):
        '''
        Plot the station positions coloured by their time shift and save the
        figure to ``fn``.

        :raises AttributeError: if no time shifts are available yet
        '''
        try:
            shifts = self.t_shifts
        except AttributeError:
            raise AttributeError('Run the snuffling first')

        # work on a copy so that repeated calls do not keep appending the
        # centre station to self.stations
        stations = list(self.stations)
        stations.append(self.station_c)
        res = to_cartesian(stations, self.station_c)
        center_xyz = res[self.station_c.nsl()[:2]]
        x = num.zeros(len(res))
        y = num.zeros(len(res))
        z = num.zeros(len(res))
        sizes = num.zeros(len(res))
        stat_labels = []
        for i, (nsl, xyz) in enumerate(res.items()):
            x[i] = xyz[0]
            y[i] = xyz[1]
            z[i] = xyz[2]
            stat_labels.append('%s' % ('.'.join(nsl)))
            try:
                sizes[i] = shifts[nsl[:2]]
            except KeyError:
                # stations without a shift keep 0 and are drawn in black
                pass

        x /= 1000.
        y /= 1000.
        z /= 1000.

        fig = plt.figure()
        cax = fig.add_axes([0.85, 0.2, 0.05, 0.5])
        ax = fig.add_axes([0.10, 0.1, 0.70, 0.7])
        ax.set_aspect('equal')
        # pyplot accessor; ``cm.get_cmap`` is not available in recent
        # Matplotlib releases
        cmap = plt.get_cmap('bwr')
        ax.scatter(x,
                   y,
                   c=sizes,
                   s=200,
                   cmap=cmap,
                   vmax=num.max(sizes),
                   vmin=-num.max(sizes))
        for i, lab in enumerate(stat_labels):
            ax.text(x[i], y[i], lab, size=14)

        unshifted = num.where(sizes == 0.)
        ax.scatter(x[unshifted], y[unshifted], c='black', alpha=0.4, s=200)

        ax.arrow(center_xyz[0] / 1000.,
                 center_xyz[1] / 1000.,
                 -num.sin(self.bazi / 180. * num.pi),
                 -num.cos(self.bazi / 180. * num.pi),
                 head_width=0.2,
                 head_length=0.2)
        ax.set_ylabel("N-S [km]")
        ax.set_xlabel("E-W [km]")
        # NOTE(review): the colour bar is normalised to [min, max] while the
        # scatter above uses [-max, max] -- confirm which is intended.
        ColorbarBase(cax,
                     cmap=cmap,
                     norm=Normalize(vmin=sizes.min(), vmax=sizes.max()))
        logger.debug('finished plotting %s' % fn)
        fig.savefig(fn)

    def save(self, traces, fn='beam.pf'):
        '''Write ``traces`` to file ``fn`` via ``io.save``.'''
        io.save(traces, fn)

    def save_station(self, fn):
        '''Dump the array centre station to file ``fn``.'''
        dump_stations([self.station_c], fn)
示例#30
0
class ProblemConfig(Object):
    """
    Config for optimization problem to setup.

    Holds the problem mode, the source and source-time-function types, the
    datatypes to use, and the priors and hyperparameters of the problem.
    """
    # Selects the variable catalog that select_variables() reads from
    # modes_catalog.
    mode = StringChoice.T(
        choices=['geometry', 'ffi', 'interseismic'],
        default='geometry',
        help='Problem to solve: "geometry", "ffi",'
        ' "interseismic"',
    )
    # __init__ fills this with FFIConfig() when mode == 'ffi' and no
    # mode_config keyword was given.
    mode_config = ModeConfig.T(
        optional=True, help='Global optimization mode specific parameters.')
    source_type = StringChoice.T(
        default='RectangularSource',
        choices=source_names,
        help='Source type to optimize for. Options: %s' %
        (', '.join(name for name in source_names)))
    stf_type = StringChoice.T(
        default='HalfSinusoid',
        choices=stf_names,
        help='Source time function type to use. Options: %s' %
        (', '.join(name for name in stf_names)))
    # Filled per datatype by set_decimation_factor(); set to None there for
    # source types other than 'RectangularSource'.
    decimation_factors = Dict.T(
        default=None,
        optional=True,
        help='Determines the reduction of discretization of an extended'
        ' source.')
    # init_vars() uses this as the prior array length for every variable
    # that is not listed in block_vars.
    n_sources = Int.T(default=1, help='Number of Sub-sources to solve for')
    datatypes = List.T(default=['geodetic'])
    # validate_hypers() treats None as "no hyper-parameters defined".
    hyperparameters = Dict.T(
        help='Hyperparameters to weight different types of datatypes.')
    # Rebuilt by init_vars() as an OrderedDict of variable name -> Parameter.
    priors = Dict.T(help='Priors of the variables in question.')

    def __init__(self, **kwargs):
        """
        Create the config, supplying a default mode config for 'ffi'.

        If the keyword `mode` equals 'ffi' and no `mode_config` keyword is
        given, a new FFIConfig() is inserted before all keywords are handed
        to Object.__init__.
        """
        wants_ffi = 'mode' in kwargs and kwargs['mode'] == 'ffi'
        if wants_ffi and 'mode_config' not in kwargs:
            kwargs['mode_config'] = FFIConfig()

        Object.__init__(self, **kwargs)

    def init_vars(self, variables=None):
        """
        Reset the priors to the default bounds of the given variables.

        Parameters
        ----------
        variables : list
            of str of variable names to initialise; if None the result of
            select_variables() is used
        """
        if variables is None:
            variables = self.select_variables()

        def filled(length, value):
            # constant array in the configured float type
            return num.ones(length, dtype=tconfig.floatX) * value

        self.priors = OrderedDict()
        for varname in variables:
            # variables in block_vars get a single entry, all others one
            # entry per sub-source
            size = 1 if varname in block_vars else self.n_sources

            bounds = default_bounds[varname]
            lower, upper = bounds[0], bounds[1]

            # NOTE(review): lower + upper / 5 is not guaranteed to lie
            # between lower and upper for arbitrary bounds.
            self.priors[varname] = Parameter(
                name=varname,
                lower=filled(size, lower),
                upper=filled(size, upper),
                testvalue=filled(size, lower + (upper / 5.)))

    def set_vars(self, bounds_dict):
        """
        Overwrite the bounds of existing priors.

        Parameters
        ----------
        bounds_dict : dict
            variable name -> bounds; bounds[0] becomes the lower and
            bounds[1] the upper bound, the mean of bounds the test value.
            Names without an existing prior are skipped with a warning.
        """
        for varname, bounds in bounds_dict.items():
            if varname not in self.priors.keys():
                logger.warning('Prior for variable %s does not exist!'
                               ' Bounds not updated!' % varname)
                continue

            prior = self.priors[varname]
            prior.lower = num.atleast_1d(bounds[0])
            prior.upper = num.atleast_1d(bounds[1])
            prior.testvalue = num.atleast_1d(num.mean(bounds))

    def select_variables(self):
        """
        Collect the model variable names for the configured mode, datatypes,
        source type and source time function.

        Returns
        -------
        list of unique variable names

        Raises
        ------
        ValueError
            for an unknown mode, an unsupported datatype or an unsupported
            source type
        Exception
            if no variable results from the configuration
        """
        if self.mode not in modes_catalog.keys():
            raise ValueError('Problem mode %s not implemented' % self.mode)

        vars_catalog = modes_catalog[self.mode]

        variables = []
        for datatype in self.datatypes:
            if datatype not in vars_catalog.keys():
                raise ValueError(
                    'Datatype %s not supported for type of'
                    ' problem! Supported datatype are: %s' %
                    (datatype, ', '.join('"%s"' % d
                                         for d in vars_catalog.keys())))

            if self.mode != 'geometry':
                # other modes list their variables directly per datatype
                variables += vars_catalog[datatype]
                continue

            if self.source_type not in vars_catalog[datatype].keys():
                raise ValueError('Source Type not supported for type'
                                 ' of problem, and datatype!')

            source = vars_catalog[datatype][self.source_type]
            source_vars = set(source.keys())

            # these source types are handled without 'magnitude'
            if isinstance(source(), (PyrockoRS, gf.ExplosionSource)):
                source_vars.discard('magnitude')

            variables += utility.weed_input_rvs(
                source_vars, self.mode, datatype)

            # seismic data additionally needs the variables of a known
            # source time function
            if datatype == 'seismic' and self.stf_type in stf_catalog.keys():
                stf = stf_catalog[self.stf_type]
                variables += utility.weed_input_rvs(
                    set(stf.keys()), self.mode, datatype)

        unique_variables = utility.unique_list(variables)

        if len(unique_variables) == 0:
            raise Exception('Mode and datatype combination not implemented'
                            ' or not resolvable with given datatypes.')

        return unique_variables

    def get_slip_variables(self):
        """
        List the slip variable names of the current mode.

        For 'ffi' and 'geometry' only names that have a prior are returned,
        for 'interseismic' always ['bl_amplitude']; any other mode yields
        None.
        """
        if self.mode == 'interseismic':
            return ['bl_amplitude']

        if self.mode == 'ffi':
            candidates = static_dist_vars
        elif self.mode == 'geometry':
            candidates = ['slip', 'magnitude']
        else:
            return None

        return [name for name in candidates if name in self.priors.keys()]

    def set_decimation_factor(self):
        """
        Set the decimation factors from the defaults of each datatype.

        Only done for the 'RectangularSource'; for every other source type
        the decimation factors are set to None.
        """
        if self.source_type != 'RectangularSource':
            self.decimation_factors = None
            return

        self.decimation_factors = {}
        for dtype_name in self.datatypes:
            factor = default_decimation_factors[dtype_name]
            self.decimation_factors[dtype_name] = factor

    def validate_priors(self):
        """
        Check if priors and their test values do not contradict!

        Calls validate_bounds() on every prior; whatever that method raises
        propagates to the caller.
        """
        # .values() exists on Python 2 and 3 dicts, .itervalues() only on 2
        for param in self.priors.values():
            param.validate_bounds()

        logger.info('All parameter-priors ok!')

    def validate_hypers(self):
        """
        Check if hyperparameters and their test values do not contradict!

        Calls validate_bounds() on every hyperparameter; if no
        hyperparameters are set (None) this is only logged.
        """
        if self.hyperparameters is None:
            logger.info('No hyper-parameters defined!')
            return

        # .values() exists on Python 2 and 3 dicts, .itervalues() only on 2
        for hp in self.hyperparameters.values():
            hp.validate_bounds()

        logger.info('All hyper-parameters ok!')