Example #1
0
def select_component_from_stream(st: obspy.core.Stream, component: str):
    """
    Helper function selecting a component from a Stream and raising the
    proper error if not found.

    This is a bit more flexible than stream.select() as it works with single
    letter channels and lowercase channels.

    :param st: Obspy stream
    :type st: obspy.core.Stream
    :param component: Name of component of stream. It is compared
        case-insensitively against the last character of each trace's
        channel code.
    :type component: str
    :returns: The single trace whose channel code ends with ``component``.
    :raises LASIFNotFoundError: If no trace, or more than one trace, matches.
    """
    component = component.upper()
    # Keep the matches in their own variable so that ``component`` still
    # holds the requested name for the error messages below.
    matches = [tr for tr in st if tr.stats.channel[-1].upper() == component]
    if not matches:
        raise LASIFNotFoundError(
            "Component %s not found in Stream." % component
        )
    elif len(matches) > 1:
        raise LASIFNotFoundError(
            "More than 1 Trace with component %s found "
            "in Stream." % component
        )
    return matches[0]
Example #2
0
    def get_project_function(self, fct_type: str):
        """
        Helper importing the project specific function.

        The function is loaded from the matching file in the project's
        ``functions`` folder and cached, so later requests for the same
        ``fct_type`` return the cached object without re-importing.

        :param fct_type: The desired function. Must be a key of the
            type/filename map below, and the file must define an attribute
            of exactly this name.
        :type fct_type: str
        :raises LASIFNotFoundError: If ``fct_type`` is unknown, its file does
            not exist, or the file does not define ``fct_type``.
        :raises LASIFError: If the attribute exists but is not callable.
        """
        # Cache to avoid repeated imports.
        if fct_type in self.__project_function_cache:
            return self.__project_function_cache[fct_type]

        # type / filename map
        fct_type_map = {
            "window_picking_function": "window_picking_function.py",
            "processing_function": "process_data.py",
            "preprocessing_function_asdf": "preprocessing_function_asdf.py",
            "process_synthetics": "process_synthetics.py",
            "source_time_function": "source_time_function.py",
            "light_preprocessing_function": "light_preprocessing.py",
        }

        # Validate against the map (not the string itself) so unknown types
        # get this descriptive error instead of a bare KeyError below.
        if fct_type not in fct_type_map:
            msg = "Function '%s' not found. Available types: %s" % (
                fct_type,
                str(list(fct_type_map.keys())),
            )
            raise LASIFNotFoundError(msg)

        filename = os.path.join(
            self.paths["functions"], fct_type_map[fct_type]
        )
        if not os.path.exists(filename):
            msg = "No file '%s' in existence." % filename
            raise LASIFNotFoundError(msg)
        # NOTE(review): every file is loaded under the same module name, and
        # ``load_module`` is deprecated in importlib; consider switching to
        # ``spec_from_file_location`` + ``exec_module``.
        fct_template = importlib.machinery.SourceFileLoader(
            "_lasif_fct_template", filename
        ).load_module("_lasif_fct_template")

        try:
            fct = getattr(fct_template, fct_type)
        except AttributeError:
            raise LASIFNotFoundError(
                "Could not find function %s in file '%s'"
                % (fct_type, filename)
            )

        if not callable(fct):
            raise LASIFError(
                "Attribute %s in file '%s' is not a function."
                % (fct_type, filename)
            )

        # Add to cache.
        self.__project_function_cache[fct_type] = fct
        return fct
Example #3
0
    def get_available_synthetics(self, event_name: str):
        """
        List the iterations for which synthetics of an event are available.

        An iteration is reported when the event's synthetics folder contains
        a sub-directory matching ``ITERATION_*``, that name can be resolved
        with ``self.comm.iterations.get`` and the event is part of the
        resolved iteration.

        :param event_name: The event name.
        :type event_name: str
        :returns: The ``name`` of every qualifying iteration, in the order
            ``os.listdir`` yields the folders.
        :rtype: list
        :raises LASIFNotFoundError: If there is no synthetics folder for the
            event.
        """
        data_dir = os.path.join(self._synthetics_folder, event_name)
        if not os.path.exists(data_dir):
            raise LASIFNotFoundError("No synthetic data for event '%s'." %
                                     event_name)

        # Sub-directories that look like iteration folders.
        candidates = [
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
            and fnmatch.fnmatch(entry, "ITERATION_*")
        ]

        # Keep only iterations LASIF knows about that also contain the event.
        available = []
        for folder in candidates:
            try:
                iteration = self.comm.iterations.get(folder)
            except LASIFNotFoundError:
                continue
            if event_name in iteration.events:
                available.append(iteration.name)
        return available
Example #4
0
    def get(self, event_name: str) -> dict:
        """
        Get information about one event.

        Either the event name or a mapping holding it under the
        ``"event_name"`` key may be passed. The extracted index values are
        memoized per event, so repeated calls are cheap.

        :param event_name: The name of the event, or a mapping with an
            ``"event_name"`` entry.
        :type event_name: str
        :raises LASIFNotFoundError: If the event is not known to LASIF.
        :rtype: dict
        """
        # Unwrap a mapping carrying the name; plain strings raise TypeError.
        try:
            event_name = event_name["event_name"]
        except (KeyError, TypeError):
            pass

        if event_name not in self.all_events:
            raise LASIFNotFoundError("Event '%s' not known to LASIF." %
                                     event_name)

        cache = self.__event_info_cache
        if event_name in cache:
            return cache[event_name]

        keys = self.index_values
        raw_values = self._extract_index_values_quakeml(
            self.all_events[event_name])
        info = dict(zip(keys, raw_values))
        info["origin_time"] = obspy.UTCDateTime(info["origin_time"])
        cache[event_name] = info
        return info
Example #5
0
    def get(self, weight_set_name: str):
        """
        Returns a weight_set object.

        :param weight_set_name: The name of the weight set to retrieve. The
            short name, the ``WEIGHTS_``-prefixed long name, or an object
            exposing a ``weight_set_name`` attribute are all accepted.
        :type weight_set_name: str
        :raises LASIFNotFoundError: If the weight set is not known.
        :returns: The weight set, cached after the first retrieval.
        :rtype: lasif.weights_toml.WeightSet
        """
        # Make it work with both the long and short version of the weight
        # set name, and with an existing weight set object.
        try:
            weight_set_name = str(weight_set_name.weight_set_name)
        except AttributeError:
            weight_set_name = str(weight_set_name)
            # Strip only a leading prefix; a blanket ``str.replace`` would
            # also mangle names that contain "WEIGHTS_" further in.
            prefix = "WEIGHTS_"
            if weight_set_name.startswith(prefix):
                weight_set_name = weight_set_name[len(prefix):]

        # Access cache.
        if weight_set_name in self.__cached_weights:
            return self.__cached_weights[weight_set_name]

        weights_dict = self.get_weight_set_dict()
        if weight_set_name not in weights_dict:
            msg = "Weights '%s' not found." % weight_set_name
            raise LASIFNotFoundError(msg)

        from lasif.weights_toml import WeightSet

        weight_set = WeightSet(weights_dict[weight_set_name])

        # Store in cache.
        self.__cached_weights[weight_set_name] = weight_set

        return weight_set
    def __init__(self, comm, iteration, event):
        # Resolve the event and iteration through the communicator.
        self.comm = comm
        self.event = self.comm.events.get(event)
        self.iteration = self.comm.iterations.get(iteration)
        self.event_name = self.event["event_name"]

        # Refuse events the iteration does not contain.
        if self.event_name not in self.iteration.events:
            msg = "Event '%s' not part of iteration '%s'." % (
                self.event_name,
                self.iteration.name,
            )
            raise LASIFNotFoundError(msg)

        def _station_ids(entries):
            # "NET.STA" identifiers of metadata entries.
            return {"%s.%s" % (e["network"], e["station"]) for e in entries}

        # Stations the iteration defines for this event ...
        defined = set(
            self.iteration.events[self.event_name]["stations"].keys())
        # ... restricted to those with both processed and synthetic data,
        # since synthetics in particular may be missing.
        processed_meta = comm.waveforms.get_metadata_processed(
            self.event_name, self.iteration.processing_tag)
        synthetic_meta = comm.waveforms.get_metadata_synthetic(
            self.event_name, self.iteration)
        usable = (defined
                  & _station_ids(processed_meta)
                  & _station_ids(synthetic_meta))
        self.stations = tuple(sorted(usable))

        self._current_index = -1
Example #7
0
 def remove_event(self, event_name):
     """Delete the row of ``event_name`` from the ``events`` table.

     Only the event row itself is removed for now.

     :raises LASIFNotFoundError: If the event is not in the database.
     """
     if not self.event_in_db(event_name):
         raise LASIFNotFoundError('Event: "{}" could not be '
                                  "found in database.".format(event_name))
     with self.sqlite_cursor() as cursor:
         cursor.execute("DELETE FROM events WHERE event_name=?",
                        (event_name, ))
Example #8
0
def plot_window_statistics(
    lasif_root,
    window_set: str,
    save: bool = False,
    events: Union[str, List[str]] = None,
    iteration: str = None,
):
    """
    Plot statistics of the windows stored in a window set.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param window_set: name of window set
    :type window_set: str
    :param save: If True, switch to the ``agg`` backend and write the figure
        to ``<window_set>.png`` in the window-statistics output folder
        instead of showing it, defaults to False
    :type save: bool, optional
    :param events: An event or a list of events. Pass None to use all
        events (of ``iteration`` if given), defaults to None
    :type events: Union[str, List[str]], optional
    :param iteration: Restrict the default event list to the events of this
        iteration, defaults to None
    :type iteration: str, optional
    :raises LASIFNotFoundError: If the window set does not exist.
    """
    import matplotlib.pyplot as plt

    comm = find_project_comm(lasif_root)

    if events is None:
        events = comm.events.list(iteration=iteration)
    event_list = [events] if isinstance(events, str) else events

    if save:
        # The figure is only written to disk, so use a non-GUI backend.
        plt.switch_backend("agg")

    if not comm.windows.has_window_set(window_set):
        raise LASIFNotFoundError("Could not find the specified window set")

    comm.visualizations.plot_window_statistics(
        window_set, event_list, ax=None, show=False)

    if not save:
        plt.show()
        return

    output_folder = comm.project.get_output_folder(
        type="window_statistics_plots", tag="windows", timestamp=False)
    outfile = os.path.join(output_folder, f"{window_set}.png")
    plt.savefig(outfile, dpi=200, transparent=True)
    print("Saved picture at %s" % outfile)
Example #9
0
 def get_event_id(self, event_name):
     """Look up the ``event_id`` stored for ``event_name``.

     :returns: The first column of the first row matched in ``events``.
     :raises LASIFNotFoundError: If the event is not in the database.
     """
     if not self.event_in_db(event_name):
         raise LASIFNotFoundError('Event: "{}" could not be'
                                  " found in database.".format(event_name))
     query = "SELECT event_id FROM events WHERE event_name = ?"
     with self.sqlite_cursor() as cursor:
         cursor.execute(query, (event_name, ))
         row = cursor.fetchone()
         return row[0]
Example #10
0
    def get_misfit_file(self, iteration: str):
        """
        Return the path of the ``misfits.toml`` file of an iteration.

        :param iteration: Name of iteration; it is expanded with
            ``self.comm.iterations.get_long_iteration_name``.
        :type iteration: str
        :raises LASIFNotFoundError: If the misfit file does not exist.
        """
        long_name = self.comm.iterations.get_long_iteration_name(iteration)
        iterations_dir = self.comm.project.paths["iterations"]
        misfit_path = iterations_dir / long_name / "misfits.toml"
        if os.path.exists(misfit_path):
            return misfit_path
        raise LASIFNotFoundError(f"File {misfit_path} does not exist")
Example #11
0
 def remove_trace(self, event_name, channel_name):
     """Delete the trace ``channel_name`` of ``event_name`` from the db.

     TODO: windows that may still reference this trace are not checked.

     Whatever ``get_event_id`` raises for an unknown event propagates.

     :raises LASIFNotFoundError: If the trace is not in the database.
     """
     event_id = self.get_event_id(event_name)
     if not self.trace_in_db(event_name, channel_name):
         raise LASIFNotFoundError("Trace {} - {} not found - could not be"
                                  " removed".format(event_name,
                                                    channel_name))
     with self.sqlite_cursor() as c:
         # One placeholder per column, bound to the channel name and the
         # event id. A DELETE produces no result row, so nothing is fetched.
         c.execute(
             "DELETE FROM traces WHERE channel_name=? AND event_id=?",
             (channel_name, event_id),
         )
Example #12
0
    def _get_waveforms(
        self,
        event_name,
        station_id,
        data_type,
        tag_or_iteration=None,
        get_inventory=False,
    ):
        """
        Read the waveforms of one station of an event from its ASDF file.

        :param event_name: Name of the event.
        :param station_id: Key into ``ds.waveforms`` (presumably
            ``"NET.STA"`` -- confirm against callers).
        :param data_type: Data type, forwarded to ``get_asdf_filename`` and
            ``_assert_tags``.
        :param tag_or_iteration: Forwarded to ``get_asdf_filename``.
        :param get_inventory: If True, also return the station's
            ``StationXML`` entry.
        :returns: The stream, or ``(stream, inventory)`` when
            ``get_inventory`` is set.
        :raises LASIFNotFoundError: If the ASDF file does not exist.
        """
        filename = self.get_asdf_filename(
            event_name=event_name,
            data_type=data_type,
            tag_or_iteration=tag_or_iteration,
        )

        if not os.path.exists(filename):
            raise LASIFNotFoundError("No '%s' waveform data found for event "
                                     "'%s' and station '%s'." %
                                     (data_type, event_name, station_id))

        with pyasdf.ASDFDataSet(filename, mode="r") as ds:
            station_group = ds.waveforms[station_id]

            tag = self._assert_tags(
                station_group=station_group,
                data_type=data_type,
                filename=filename,
            )

            # Get the waveform data.
            st = station_group[tag]

            # Make sure it only contains data from a single location. Only
            # act when there are several codes; an empty stream has no
            # location to choose from.
            locs = sorted({tr.stats.location for tr in st})
            if len(locs) > 1:
                msg = ("File '%s' contains %i location codes for station "
                       "'%s'. The alphabetically first one will be chosen." %
                       (filename, len(locs), station_id))
                warnings.warn(msg, LASIFWarning)

                # ``Stream.select`` picks traces by metadata, whereas
                # ``Stream.filter`` would apply a frequency filter.
                st = st.select(location=locs[0])

            if get_inventory:
                inv = station_group["StationXML"]
                return st, inv

            return st
Example #13
0
def clean_up(lasif_root, clean_up_file: str):
    """
    Clean up the lasif project based on a clean-up file.

    The clean-up file can be produced with the ``validate_data`` command.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param clean_up_file: path to clean-up file
    :type clean_up_file: str
    :raises LASIFNotFoundError: If ``clean_up_file`` does not exist.
    """
    comm = find_project_comm(lasif_root)
    if os.path.exists(clean_up_file):
        comm.validator.clean_up_project(clean_up_file)
        return
    raise LASIFNotFoundError(f"Could not find {clean_up_file}\n"
                             f"Please check that the specified file path "
                             f"is correct.")
Example #14
0
def load_receivers(comm: object, event: str):
    """
    Loads receivers which have already been written into a json file.

    The file is read from
    ``<salvus_files>/RECEIVERS/<event>/receivers.json``.

    :param comm: LASIF communicator object
    :type comm: object
    :param event: The name of the event whose receivers should be loaded.
    :type event: str
    :returns: The decoded JSON content of ``receivers.json``.
    :raises LASIFNotFoundError: If the receivers file does not exist.
    """
    import json

    filename = (comm.project.paths["salvus_files"] / "RECEIVERS" / event /
                "receivers.json")
    if not os.path.exists(filename):
        # Name the missing file so the user knows what to generate.
        raise LASIFNotFoundError(
            f"Could not find receivers file {filename} for event {event}.")
    with open(filename, "r") as fh:
        return json.load(fh)
Example #15
0
def clean_up(lasif_root,
             clean_up_file: str,
             delete_outofbounds_events: bool = False):
    """
    Clean up the lasif project based on a clean-up file.

    The clean-up file can be produced with the ``validate_data`` command.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param clean_up_file: path to clean-up file
    :type clean_up_file: str
    :param delete_outofbounds_events: Whether full event files should be
        deleted if the event is out of the domain, defaults to False.
    :type delete_outofbounds_events: bool, optional
    :raises LASIFNotFoundError: If ``clean_up_file`` does not exist.
    """
    comm = find_project_comm(lasif_root)
    if os.path.exists(clean_up_file):
        comm.validator.clean_up_project(clean_up_file,
                                        delete_outofbounds_events)
        return
    raise LASIFNotFoundError(f"Could not find {clean_up_file}\n"
                             f"Please check that the specified file path "
                             f"is correct.")
Example #16
0
def compute_station_weights(
    lasif_root,
    weight_set: str,
    events: Union[str, List[str]] = None,
    iteration: str = None,
):
    """
    Compute weights for stations based on amount of neighbouring stations

    If the weight set does not exist yet it is created first. For every
    event, the raw weight of each station comes from
    ``comm.weights.calculate_station_weight``. The weights are then rescaled
    so that they average to 1 over the event's stations. An event with a
    single station gets weight 1.0.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param weight_set: name of weight set to compute into
    :type weight_set: str
    :param events: An event or a list of events. To get all of them pass
        None, defaults to None
    :type events: Union[str, List[str]], optional
    :param iteration: name of iteration to compute weights for, controls
        the events which are picked for computing weights for.
    :type iteration: str, optional
    :raises LASIFNotFoundError: If one of the events is not known to LASIF.
    """

    comm = find_project_comm(lasif_root)

    if events is None:
        events = comm.events.list(iteration=iteration)
    if isinstance(events, str):
        events = [events]
    events_dict = {}
    if not comm.weights.has_weight_set(weight_set):
        for event in events:
            events_dict[event] = comm.query.get_all_stations_for_event(event)
        comm.weights.create_new_weight_set(
            weight_set_name=weight_set,
            events_dict=events_dict,
        )

    w_set = comm.weights.get(weight_set)
    from tqdm import tqdm

    for event in events:
        print(f"Calculating station weights for event: {event}")
        if not comm.events.has_event(event):
            raise LASIFNotFoundError(f"Event: {event} is not known to LASIF")
        stations = comm.query.get_all_stations_for_event(event)
        events_dict[event] = list(stations.keys())
        locations = np.zeros((2, len(stations)), dtype=np.float64)
        for _i, station in enumerate(stations):
            locations[0, _i] = stations[station]["latitude"]
            locations[1, _i] = stations[station]["longitude"]

        sum_value = 0.0

        for station in tqdm(stations):
            weight = comm.weights.calculate_station_weight(
                lat_1=stations[station]["latitude"],
                lon_1=stations[station]["longitude"],
                locations=locations,
            )
            sum_value += weight
            w_set.events[event]["stations"][station]["station_weight"] = weight

        if len(stations) == 1:
            # A lone station gets unit weight. Index by the station name
            # itself: ``stations[name]`` is the station's info dict, which
            # cannot be used as a key. Skipping the normalisation also avoids
            # dividing by the station's raw weight.
            only_station = next(iter(stations))
            w_set.events[event]["stations"][only_station][
                "station_weight"] = 1.0
        else:
            # Rescale so the event's station weights average to 1.
            for station in stations:
                w_set.events[event]["stations"][station]["station_weight"] *= (
                    len(stations) / sum_value)

    comm.weights.change_weight_set(
        weight_set_name=weight_set,
        weight_set=w_set,
        events_dict=events_dict,
    )
Example #17
0
def calculate_adjoint_sources(
    lasif_root,
    iteration: str,
    window_set: str,
    weight_set: str = None,
    events: Union[str, List[str]] = None,
):
    """
    Calculate adjoint sources for a given iteration.

    Rank 0 validates the inputs, prints progress, deletes stale output and
    finalizes each event. All ranks take part in
    ``comm.adj_sources.calculate_adjoint_sources`` between two barriers.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param iteration: name of iteration
    :type iteration: str
    :param window_set: name of window set
    :type window_set: str
    :param weight_set: name of station weight set, defaults to None
    :type weight_set: str, optional
    :param events: Name of event or list of events. To get all events for
        the iteration, pass None, defaults to None
    :type events: Union[str, List[str]]
    :raises LASIFNotFoundError: On rank 0, if the window set or iteration is
        unknown (other ranks simply return).
    """
    from mpi4py import MPI

    comm = find_project_comm(lasif_root)
    is_root = MPI.COMM_WORLD.rank == 0

    # Input validation: only the root rank raises, the others just stop.
    if not comm.windows.has_window_set(window_set):
        if is_root:
            raise LASIFNotFoundError(
                "Window set {} not known to LASIF".format(window_set))
        return

    if not comm.iterations.has_iteration(iteration):
        if is_root:
            raise LASIFNotFoundError(
                "Iteration {} not known to LASIF".format(iteration))
        return

    if events is None:
        events = comm.events.list(iteration=iteration)
    if isinstance(events, str):
        events = [events]

    for position, event in enumerate(events, start=1):
        if not comm.events.has_event(event):
            if is_root:
                print("Event '%s' not known to LASIF. No adjoint sources for "
                      "this event will be calculated. " % event)
            continue

        if is_root:
            print("\n{green}"
                  "==========================================================="
                  "{reset}".format(green=colorama.Fore.GREEN,
                                   reset=colorama.Style.RESET_ALL))
            print("Starting adjoint source calculation for event %i of "
                  "%i..." % (position, len(events)))
            print("{green}"
                  "==========================================================="
                  "{reset}\n".format(green=colorama.Fore.GREEN,
                                     reset=colorama.Style.RESET_ALL))

            # Start from scratch: drop adjoint sources from an earlier run.
            stale_file = comm.adj_sources.get_filename(event=event,
                                                       iteration=iteration)
            if os.path.exists(stale_file):
                os.remove(stale_file)

        MPI.COMM_WORLD.barrier()
        comm.adj_sources.calculate_adjoint_sources(event, iteration,
                                                   window_set)
        MPI.COMM_WORLD.barrier()
        if is_root:
            comm.adj_sources.finalize_adjoint_sources(iteration, event,
                                                      weight_set)
Example #18
0
    def select_windows_multiprocessing(self,
                                       event: str,
                                       iteration_name: str,
                                       window_set_name: str,
                                       num_processes: int = 16,
                                       **kwargs):
        """
        Automatically select the windows for the given event and iteration.
        Uses Python's multiprocessing for parallelization.

        Each station of the processed ASDF file is handed to a worker. The
        worker pairs observed and processed synthetic traces per component
        (E, N, Z) and runs the project's ``window_picking_function`` on them.
        The collected windows are written with
        ``self.comm.windows.write_windows_to_sql``.

        :param event: The event.
        :type event: str
        :param iteration_name: The iteration.
        :type iteration_name: str
        :param window_set_name: The name of the window set to pick into
        :type window_set_name: str
        :param num_processes: The maximum number of worker processes; capped
            at ``multiprocessing.cpu_count()``.
        :type num_processes: int
        :param kwargs: Forwarded to the project's window picking function.
        :raises LASIFNotFoundError: If the processed or synthetic ASDF file
            does not exist.
        """
        from lasif.utils import select_component_from_stream
        from tqdm import tqdm
        import multiprocessing
        import warnings
        import numpy as np
        import pyasdf
        # NOTE(review): this silences all warnings for the rest of the
        # process, not only during this call.
        warnings.filterwarnings("ignore")

        # Presumably declared global so the nested worker below is bound at
        # module level and can be pickled by reference for the pool.
        global _window_select

        event = self.comm.events.get(event)

        # Get the ASDF filenames.
        processed_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event["event_name"],
            data_type="processed",
            tag_or_iteration=self.comm.waveforms.preprocessing_tag,
        )
        synthetic_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event["event_name"],
            data_type="synthetic",
            tag_or_iteration=iteration_name,
        )

        if not os.path.exists(processed_filename):
            msg = "File '%s' does not exists." % processed_filename
            raise LASIFNotFoundError(msg)

        if not os.path.exists(synthetic_filename):
            msg = "File '%s' does not exists." % synthetic_filename
            raise LASIFNotFoundError(msg)

        # Load project specific window selection function.
        select_windows = self.comm.project.get_project_function(
            "window_picking_function")

        # Get source time function
        stf_fct = self.comm.project.get_project_function(
            "source_time_function")
        delta = self.comm.project.simulation_settings["time_step_in_s"]
        npts = self.comm.project.simulation_settings["number_of_time_steps"]
        freqmax = (
            1.0 / self.comm.project.simulation_settings["minimum_period_in_s"])
        freqmin = (
            1.0 / self.comm.project.simulation_settings["maximum_period_in_s"])
        stf_trace = stf_fct(npts=npts,
                            delta=delta,
                            freqmin=freqmin,
                            freqmax=freqmax)

        process_params = self.comm.project.simulation_settings
        minimum_period = process_params["minimum_period_in_s"]
        maximum_period = process_params["maximum_period_in_s"]

        def _pick_windows(ds, ds_synth, station):
            # Pick the windows of one station from already opened data sets.
            observed_station = ds.waveforms[station]
            synthetic_station = ds_synth.waveforms[station]

            obs_tag = observed_station.get_waveform_tags()
            syn_tag = synthetic_station.get_waveform_tags()

            # Exactly one waveform tag is required on either side; skip the
            # station otherwise. An explicit check, unlike ``assert``,
            # survives ``python -O``.
            if len(obs_tag) != 1 or len(syn_tag) != 1:
                return {station: None}

            obs_tag = obs_tag[0]
            syn_tag = syn_tag[0]

            # Finally get the data.
            st_obs = observed_station[obs_tag]
            st_syn = synthetic_station[syn_tag]

            # Extract coordinates once.
            coordinates = observed_station.coordinates

            # Process the synthetics.
            st_syn = self.comm.waveforms.process_synthetics(
                st=st_syn.copy(),
                event_name=event["event_name"],
                iteration=iteration_name,
            )

            all_windows = {}
            for component in ["E", "N", "Z"]:
                try:
                    data_tr = select_component_from_stream(st_obs, component)
                    synth_tr = select_component_from_stream(st_syn, component)

                    if self.comm.project.simulation_settings[
                            "scale_data_to_synthetics"]:
                        # ``np.ptp`` instead of the ``ndarray.ptp`` method,
                        # which was removed in NumPy 2.0.
                        scaling_factor = (np.ptp(synth_tr.data) /
                                          np.ptp(data_tr.data))
                        # Store and apply the scaling.
                        data_tr.stats.scaling_factor = scaling_factor
                        data_tr.data *= scaling_factor

                except LASIFNotFoundError:
                    continue

                windows = None
                try:
                    windows = select_windows(
                        data_tr,
                        synth_tr,
                        stf_trace,
                        event["latitude"],
                        event["longitude"],
                        event["depth_in_km"],
                        coordinates["latitude"],
                        coordinates["longitude"],
                        minimum_period=minimum_period,
                        maximum_period=maximum_period,
                        iteration=iteration_name,
                        **kwargs,
                    )
                except Exception as e:
                    # Best effort: a failing component is reported and
                    # skipped instead of aborting the whole station.
                    print(e)

                if not windows:
                    continue
                all_windows[data_tr.id] = windows

            if all_windows:
                return {station: all_windows}
            else:
                return {station: None}

        def _window_select(station):
            # Open both files read-only for this station only and close
            # them again afterwards so the workers do not leak file handles.
            with pyasdf.ASDFDataSet(processed_filename, mode="r",
                                    mpi=False) as ds:
                with pyasdf.ASDFDataSet(synthetic_filename, mode="r",
                                        mpi=False) as ds_synth:
                    return _pick_windows(ds, ds_synth, station)

        # Generate task list
        with pyasdf.ASDFDataSet(processed_filename, mode="r", mpi=False) as ds:
            task_list = ds.waveforms.list()

        # Use at most num_processes workers
        number_processes = min(num_processes, multiprocessing.cpu_count())

        # Open Pool of workers
        with multiprocessing.Pool(number_processes) as pool:
            results = {}
            with tqdm(total=len(task_list)) as pbar:
                for result in pool.imap_unordered(_window_select, task_list):
                    pbar.update()
                    sta_name, sta_windows = result.popitem()
                    results[sta_name] = sta_windows

            pool.close()
            pool.join()

        # Write files with a single worker
        print("Finished window selection", flush=True)
        num_sta_with_windows = sum(v is not None for v in results.values())
        print(f"Writing windows for {num_sta_with_windows} out of "
              f"{len(task_list)} stations.")
        self.comm.windows.write_windows_to_sql(
            event_name=event["event_name"],
            windows=results,
            window_set_name=window_set_name,
        )
Example #19
0
    def plot_section(
        self,
        event_name: str,
        data_type: str = "processed",
        component: str = "Z",
        num_bins: int = 1,
        traces_per_bin: int = 500,
    ):
        """
        Create a section plot of an event and store the plot in Output. Useful
        for quickly inspecting if an event is good for usage.

        The plot is written to ``<tag>.png`` in the ``section_plots`` output
        folder of the event, where ``tag`` is the project's preprocessing tag.

        :param event_name: Name of the event
        :type event_name: str
        :param data_type: The type of data, one of: raw, processed (default)
        :type data_type: str, optional
        :param component: Component of the data Z(default), N, E
        :type component: str, optional
        :param num_bins: number of offset bins, defaults to 1
        :type num_bins: int, optional
        :param traces_per_bin: number of traces per bin, defaults to 500
        :type traces_per_bin: int, optional
        :raises LASIFNotFoundError: If the ASDF file does not exist.
        """
        import pyasdf
        import obspy

        from pathlib import Path

        event = self.comm.events.get(event_name)
        tag = self.comm.waveforms.preprocessing_tag
        asdf_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event_name, data_type=data_type, tag_or_iteration=tag)

        asdf_file = Path(asdf_filename)
        if not asdf_file.is_file():
            raise LASIFNotFoundError(f"Could not find {asdf_file.name}")

        # get event coords
        ev_coord = [event["latitude"], event["longitude"]]

        section_st = obspy.core.stream.Stream()
        # Plotting only reads, so open read-only (pyasdf otherwise defaults
        # to append mode) and close the file once the traces are collected.
        with pyasdf.ASDFDataSet(asdf_filename, mode="r") as ds:
            for station in ds.waveforms.list():
                sta = ds.waveforms[station]
                st = obspy.core.stream.Stream()

                tags = sta.get_waveform_tags()
                if tags:
                    st = sta[tags[0]]

                st = st.select(component=component)
                if len(st) > 0:
                    st[0].stats["coordinates"] = sta.coordinates
                    lat = sta.coordinates["latitude"]
                    lon = sta.coordinates["longitude"]
                    # Rough epicentral offset in degrees, used for binning.
                    offset = np.sqrt((ev_coord[0] - lat)**2 +
                                     (ev_coord[1] - lon)**2)
                    st[0].stats["offset"] = offset

                section_st += st

        if num_bins > 1:
            section_st = get_binned_stream(section_st,
                                           num_bins=num_bins,
                                           num_bin_tr=traces_per_bin)
        else:
            section_st = section_st[:traces_per_bin]

        outfile = os.path.join(
            self.comm.project.get_output_folder(type="section_plots",
                                                tag=event_name,
                                                timestamp=False),
            f"{tag}.png",
        )

        section_st.plot(
            type="section",
            dist_degree=True,
            ev_coord=ev_coord,
            scale=2.0,
            outfile=outfile,
        )
        print("Saved picture at %s" % outfile)
Example #20
0
    def _read(self):
        """
        Reads the HDF5 file and gathers basic information such as the
        coordinates of the edge nodes. In the case of domain that spans
        the entire earth, all points will lie inside the domain, therefore
        further processing is not necessary.

        For a global mesh only ``is_global_mesh`` and the lat/lon extent
        are set. For a regional mesh this also sets the edge, surface and
        bottom coordinates, the domain centre, ``max_depth``, ``boundary``
        and ``edge_polygon``, sets ``is_read`` and closes the file.

        :raises LASIFNotFoundError: If the mesh file cannot be opened.
        """
        try:
            self.m = h5py.File(self.mesh_file, mode="r")
        except (AssertionError, OSError) as err:
            # h5py reports a missing or unreadable file with an OSError
            # (e.g. FileNotFoundError); AssertionError is kept for
            # compatibility with the previous behaviour.
            msg = (
                "Could not open the project's mesh file. "
                "Please ensure that the path specified "
                "in config is correct."
            )
            raise LASIFNotFoundError(msg) from err

        self.side_set_names = list(self.m["SIDE_SETS"].keys())
        # Meshes with at most two side sets and no "outer_boundary", or
        # with an "a0" side set, are treated as global: every point lies
        # inside the domain, so no boundary needs to be extracted.
        is_global = (
            len(self.side_set_names) <= 2
            and "outer_boundary" not in self.side_set_names
        ) or "a0" in self.side_set_names
        if is_global:
            self.is_global_mesh = True
            self.min_lat = -90.0
            self.max_lat = 90.0
            self.min_lon = -180.0
            self.max_lon = 180.0
            # NOTE(review): self.m is left open on this path; confirm
            # whether other methods rely on that before closing it here.
            return

        # Sort the side sets into bottom ("r0"), earth surface ("r1" or
        # "r1_ol") and lateral domain boundaries (everything else).
        side_elements = []
        earth_surface_elements = []
        earth_bottom_elements = []
        for side_set in self.side_set_names:
            if side_set == "surface":
                continue
            elements = self.m["SIDE_SETS"][side_set]["elements"][()]
            if side_set == "r0":
                earth_bottom_elements = elements
            elif side_set in ("r1", "r1_ol"):
                earth_surface_elements = elements
            else:
                side_elements.append(elements)

        # Merge the lateral side sets and remove duplicates. The empty
        # integer seed keeps np.concatenate valid when there are none.
        # The builtin ``int`` replaces ``np.int``, which NumPy 1.24 removed.
        side_elements = np.unique(
            np.concatenate([np.array([], dtype=int), *side_elements])
        )

        # Elements on both a lateral boundary and the surface/bottom form
        # the top and bottom edges of the domain.
        surface_boundaries = np.intersect1d(
            side_elements, earth_surface_elements
        )
        bottom_boundaries = np.intersect1d(
            side_elements, earth_bottom_elements
        )

        # Get coordinates
        coords = self.m["MODEL/coordinates"][()]
        self.domain_edge_coords = coords[surface_boundaries]
        self.earth_surface_coords = coords[earth_surface_elements]
        self.earth_bottom_coords = coords[earth_bottom_elements]
        self.bottom_edge_coords = coords[bottom_boundaries]

        # Use the first GLL point of every surface element to estimate the
        # lateral extent and centre of the domain (assumes coordinates are
        # laid out as (element, gll_point, xyz) — TODO confirm).
        x, y, z = self.earth_surface_coords[:, 0, :].T

        # get center lat/lon
        x_cen, y_cen, z_cen = np.median(x), np.median(y), np.median(z)
        self.center_lat, self.center_lon, _ = xyz_to_lat_lon_radius(
            x_cen, y_cen, z_cen
        )

        lats, lons, _ = xyz_to_lat_lon_radius(x, y, z)
        self.min_lat = np.min(lats)
        self.max_lat = np.max(lats)
        self.min_lon = np.min(lons)
        self.max_lon = np.max(lons)

        # Find point outside the domain:
        outside_point = self.find_outside_point()

        # The maximum depth follows from the smallest radius among the
        # first GLL points of the earth-bottom elements.
        x, y, z = self.earth_bottom_coords.T
        x, y, z = x[0], y[0], z[0]
        _, _, r = xyz_to_lat_lon_radius(x, y, z)
        min_r = min(r)
        self.max_depth = self.r_earth - min_r

        self.is_read = True

        # In order to create the self.edge_polygon we need to make sure that
        # the points on the boundary are arranged in a way that a proper
        # polygon will be drawn.
        sorted_indices = self.get_sorted_edge_coords()
        x, y, z = self.domain_edge_coords[np.append(sorted_indices, 0)].T
        lats, lons, _ = xyz_to_lat_lon_radius(x[0], y[0], z[0])

        x, y, z = normalize_coordinates(x[0], y[0], z[0])
        points = np.array((x, y, z)).T

        self.boundary = np.array([lats, lons]).T
        self.edge_polygon = lasif.spherical_geometry.SphericalPolygon(
            points, outside_point
        )
        # Close file
        self.m.close()
Example #21
0
    def select_windows(self, event: str, iteration_name: str,
                       window_set_name: str, **kwargs):
        """
        Automatically select the windows for the given event and iteration.

        Function must be called with MPI.

        Observed and synthetic traces are compared component by component
        (E, N, Z) with the project's window picking function. Afterwards
        every rank writes its own results into the window set, one rank
        after the other.

        :param event: The event.
        :type event: str
        :param iteration_name: The iteration.
        :type iteration_name: str
        :param window_set_name: The name of the window set to pick into
        :type window_set_name: str
        :param kwargs: Forwarded to the project's window picking function.
        :raises LASIFNotFoundError: If the processed data or the synthetics
            ASDF file does not exist.
        """
        from lasif.utils import select_component_from_stream

        from mpi4py import MPI
        import numpy as np
        import pyasdf

        event = self.comm.events.get(event)

        # Get the ASDF filenames.
        processed_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event["event_name"],
            data_type="processed",
            tag_or_iteration=self.comm.waveforms.preprocessing_tag,
        )
        synthetic_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event["event_name"],
            data_type="synthetic",
            tag_or_iteration=iteration_name,
        )

        for path in (processed_filename, synthetic_filename):
            if not os.path.exists(path):
                msg = "File '%s' does not exists." % path
                raise LASIFNotFoundError(msg)

        # Load project specific window selection function.
        select_windows = self.comm.project.get_project_function(
            "window_picking_function")

        # Build the source time function over the simulation's bandwidth.
        stf_fct = self.comm.project.get_project_function(
            "source_time_function")
        process_params = self.comm.project.simulation_settings
        minimum_period = process_params["minimum_period_in_s"]
        maximum_period = process_params["maximum_period_in_s"]
        stf_trace = stf_fct(
            npts=process_params["number_of_time_steps"],
            delta=process_params["time_step_in_s"],
            freqmin=1.0 / maximum_period,
            freqmax=1.0 / minimum_period,
        )

        def process(observed_station, synthetic_station):
            """Return {trace_id: windows} for one station, or None."""
            obs_tag = observed_station.get_waveform_tags()
            syn_tag = synthetic_station.get_waveform_tags()

            # Make sure both have length 1.
            assert len(obs_tag) == 1, (
                "Station: %s - Requires 1 observed waveform tag. Has %i." %
                (observed_station._station_name, len(obs_tag)))
            assert len(syn_tag) == 1, (
                "Station: %s - Requires 1 synthetic waveform tag. Has %i." %
                (observed_station._station_name, len(syn_tag)))

            # Finally get the data.
            st_obs = observed_station[obs_tag[0]]
            st_syn = synthetic_station[syn_tag[0]]

            # Extract coordinates once.
            coordinates = observed_station.coordinates

            # Process the synthetics.
            st_syn = self.comm.waveforms.process_synthetics(
                st=st_syn.copy(),
                event_name=event["event_name"],
                iteration=iteration_name,
            )

            all_windows = {}

            for component in ["E", "N", "Z"]:
                try:
                    data_tr = select_component_from_stream(st_obs, component)
                    synth_tr = select_component_from_stream(st_syn, component)
                except LASIFNotFoundError:
                    continue

                if process_params["scale_data_to_synthetics"]:
                    # np.ptp instead of ndarray.ptp (removed in NumPy 2.0).
                    data_ptp = np.ptp(data_tr.data)
                    # A flat data trace would give an infinite factor and
                    # turn the data into inf/NaN, so leave it unscaled.
                    if data_ptp != 0:
                        scaling_factor = np.ptp(synth_tr.data) / data_ptp
                        # Store and apply the scaling.
                        data_tr.stats.scaling_factor = scaling_factor
                        data_tr.data *= scaling_factor

                windows = None
                try:
                    windows = select_windows(
                        data_tr,
                        synth_tr,
                        stf_trace,
                        event["latitude"],
                        event["longitude"],
                        event["depth_in_km"],
                        coordinates["latitude"],
                        coordinates["longitude"],
                        minimum_period=minimum_period,
                        maximum_period=maximum_period,
                        iteration=iteration_name,
                        **kwargs,
                    )
                except Exception as e:
                    # Best effort: a failing component must not abort the
                    # whole station.
                    print(e)

                if not windows:
                    continue
                all_windows[data_tr.id] = windows

            # Stations without any windows yield None.
            return all_windows or None

        # Close both data sets again once processing has finished.
        with pyasdf.ASDFDataSet(processed_filename, mode="r",
                                mpi=False) as ds:
            with pyasdf.ASDFDataSet(synthetic_filename, mode="r",
                                    mpi=False) as ds_synth:
                results = process_two_files_without_parallel_output(
                    ds, ds_synth, process)

        mpi_comm = MPI.COMM_WORLD
        mpi_comm.Barrier()
        if mpi_comm.rank == 0:
            print("Finished window selection", flush=True)
        size = mpi_comm.size
        mpi_comm.Barrier()
        # Ranks write one after another, separated by barriers.
        for thread in range(size):
            rank = mpi_comm.rank
            if rank == thread:
                print(
                    f"Writing windows for rank: {rank+1} "
                    f"out of {size}",
                    flush=True,
                )
                self.comm.windows.write_windows_to_sql(
                    event_name=event["event_name"],
                    windows=results,
                    window_set_name=window_set_name,
                )
            mpi_comm.Barrier()
Example #22
0
    def calculate_adjoint_sources_multiprocessing(
        self,
        event: str,
        iteration: str,
        window_set_name: str,
        num_processes: int,
        plot: bool = False,
        **kwargs,
    ):
        """
        Calculate adjoint sources based on the type of misfit defined in
        the lasif config file.
        The computed misfit for each station is also written down into
        a misfit toml file.
        This function uses multiprocessing for parallelization

        :param event: Name of event
        :type event: str
        :param iteration: Name of iteration
        :type iteration: str
        :param window_set_name: Name of window set
        :type window_set_name: str
        :param num_processes: The number of processes used in
            multiprocessing, capped at the number of available CPUs
        :type num_processes: int
        :param plot: Should the adjoint source be plotted?, defaults to False
        :type plot: bool, optional
        :param kwargs: Currently unused.
        :raises LASIFNotFoundError: If the processed data or the synthetics
            ASDF file does not exist.
        """
        from lasif.utils import select_component_from_stream
        from tqdm import tqdm
        import multiprocessing
        import warnings

        import numpy as np

        warnings.filterwarnings("ignore")

        # Globally define the processing function. This is required to enable
        # pickling of a function within a function. Alternatively, a solution
        # can be found that does not utilize a function within a function.
        global _process
        event = self.comm.events.get(event)
        event_name = event["event_name"]

        # Get the ASDF filenames.
        processed_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event_name,
            data_type="processed",
            tag_or_iteration=self.comm.waveforms.preprocessing_tag,
        )
        synthetic_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event_name,
            data_type="synthetic",
            tag_or_iteration=iteration,
        )

        for path in (processed_filename, synthetic_filename):
            if not os.path.exists(path):
                msg = "File '%s' does not exists." % path
                raise LASIFNotFoundError(msg)

        all_windows = self.comm.windows.read_all_windows(
            event=event_name, window_set_name=window_set_name)

        process_params = self.comm.project.simulation_settings

        def _process(station):
            """Return {station: {trace_id: misfit/adjoint source}}."""
            adjoint_sources = {}

            # Read both streams and close the files again right away so
            # that no HDF5 handles leak across the many station tasks.
            with pyasdf.ASDFDataSet(processed_filename, mode="r",
                                    mpi=False) as ds:
                with pyasdf.ASDFDataSet(synthetic_filename, mode="r",
                                        mpi=False) as ds_synth:
                    observed_station = ds.waveforms[station]
                    synthetic_station = ds_synth.waveforms[station]
                    obs_tags = observed_station.get_waveform_tags()
                    syn_tags = synthetic_station.get_waveform_tags()

                    # Both files need exactly one waveform tag, otherwise
                    # the station is skipped. An explicit check rather than
                    # ``assert`` so this also holds under ``python -O``.
                    if len(obs_tags) != 1 or len(syn_tags) != 1:
                        return {station: adjoint_sources}

                    st_obs = observed_station[obs_tags[0]]
                    st_syn = synthetic_station[syn_tags[0]]

            # Process the synthetics.
            st_syn = self.comm.waveforms.process_synthetics(
                st=st_syn.copy(),
                event_name=event_name,
                iteration=iteration,
            )

            misfit_type = self.comm.project.optimization_settings[
                "misfit_type"]
            if misfit_type == "weighted_waveform_misfit":
                env_scaling = True
                ad_src_type = "waveform_misfit"
            else:
                env_scaling = False
                ad_src_type = misfit_type

            for component in ["E", "N", "Z"]:
                try:
                    data_tr = select_component_from_stream(st_obs, component)
                    synth_tr = select_component_from_stream(st_syn, component)
                except LASIFNotFoundError:
                    continue

                if (process_params["scale_data_to_synthetics"]
                        and misfit_type != "L2NormWeighted"):
                    # np.ptp instead of ndarray.ptp (removed in NumPy 2.0).
                    data_ptp = np.ptp(data_tr.data)
                    # A flat data trace would give an infinite factor and
                    # turn the data into inf/NaN, so leave it unscaled.
                    if data_ptp != 0:
                        scaling_factor = np.ptp(synth_tr.data) / data_ptp
                        # Store and apply the scaling.
                        data_tr.stats.scaling_factor = scaling_factor
                        data_tr.data *= scaling_factor

                # Separate name so the ``station`` task key is not shadowed.
                net, sta, _ = data_tr.id.split(".", 2)
                station_id = f"{net}.{sta}"

                if station_id not in all_windows:
                    continue
                if data_tr.id not in all_windows[station_id]:
                    continue
                windows = all_windows[station_id][data_tr.id]
                try:
                    asrc = calculate_adjoint_source(
                        observed=data_tr,
                        synthetic=synth_tr,
                        window=windows,
                        min_period=process_params["minimum_period_in_s"],
                        max_period=process_params["maximum_period_in_s"],
                        adj_src_type=ad_src_type,
                        window_set=window_set_name,
                        taper_ratio=0.15,
                        taper_type="cosine",
                        plot=plot,
                        envelope_scaling=env_scaling,
                    )
                except Exception:
                    # Either pass or fail for the whole component.
                    continue

                if not asrc:
                    continue
                adjoint_sources[data_tr.id] = {
                    "misfit": asrc.misfit,
                    "adj_source": asrc.adjoint_source.data,
                }
            return {station: adjoint_sources}

        # Generate task list; the file is closed again before forking.
        with pyasdf.ASDFDataSet(processed_filename, mode="r",
                                mpi=False) as ds:
            task_list = ds.waveforms.list()

        # Use at most num_processes
        number_processes = min(num_processes, multiprocessing.cpu_count())

        results = {}
        with multiprocessing.Pool(number_processes) as pool:
            with tqdm(total=len(task_list)) as pbar:
                for station_result in pool.imap_unordered(_process,
                                                          task_list):
                    pbar.update()
                    results.update(station_result)

            pool.close()
            pool.join()

        # Write adjoint sources
        filename = self.get_filename(event=event_name, iteration=iteration)
        long_iter_name = self.comm.iterations.get_long_iteration_name(
            iteration)
        misfit_toml = self.comm.project.paths["iterations"]
        toml_filename = misfit_toml / long_iter_name / "misfits.toml"

        # Reset the event misfit before adding the new station misfits.
        if os.path.exists(toml_filename):
            iteration_misfits = toml.load(toml_filename)
            if event_name in iteration_misfits.keys():
                iteration_misfits[event_name]["event_misfit"] = 0.0
            with open(toml_filename, "w") as fh:
                toml.dump(iteration_misfits, fh)

        print("Writing adjoint sources...")
        with pyasdf.ASDFDataSet(filename=filename, mpi=False, mode="a") as bs:
            if toml_filename.exists():
                iteration_misfits = toml.load(toml_filename)
                if event_name in iteration_misfits.keys():
                    total_misfit = iteration_misfits[event_name][
                        "event_misfit"]
                else:
                    iteration_misfits[event_name] = {"stations": {}}
                    total_misfit = 0.0
            else:
                iteration_misfits = {event_name: {"stations": {}}}
                total_misfit = 0.0

            for value in results.values():
                if not value:
                    continue
                station_misfit = 0.0
                for c_id, adj_source in value.items():
                    net, sta, loc, cha = c_id.split(".")

                    bs.add_auxiliary_data(
                        data=adj_source["adj_source"],
                        data_type="AdjointSources",
                        path="%s_%s/Channel_%s_%s" % (net, sta, loc, cha),
                        parameters={"misfit": adj_source["misfit"]},
                    )
                    station_misfit += adj_source["misfit"]
                    station_name = f"{net}.{sta}"
                iteration_misfits[event_name]["stations"][
                    station_name] = float(station_misfit)
                total_misfit += station_misfit
            iteration_misfits[event_name]["event_misfit"] = float(
                total_misfit)
            with open(toml_filename, "w") as fh:
                toml.dump(iteration_misfits, fh)

        with pyasdf.ASDFDataSet(filename=filename, mpi=False, mode="a") as ds:
            length = len(ds.auxiliary_data.AdjointSources.list())
        print(f"{length} Adjoint sources are in your file.")
Example #23
0
    def calculate_adjoint_sources(
        self,
        event: str,
        iteration: str,
        window_set_name: str,
        plot: bool = False,
        **kwargs,
    ):
        """
        Calculate adjoint sources based on the type of misfit defined in
        the lasif config file.
        The computed misfit for each station is also written down into
        a misfit toml file.

        Must be called with MPI: windows are read on rank 0 and broadcast,
        and every rank writes its own results one after another.

        :param event: Name of event
        :type event: str
        :param iteration: Name of iteration
        :type iteration: str
        :param window_set_name: Name of window set
        :type window_set_name: str
        :param plot: Should the adjoint source be plotted?, defaults to False
        :type plot: bool, optional
        :param kwargs: Currently unused.
        :raises LASIFNotFoundError: If the processed data or the synthetics
            ASDF file does not exist.
        """
        from lasif.utils import select_component_from_stream

        from mpi4py import MPI
        import numpy as np
        import pyasdf

        event = self.comm.events.get(event)
        event_name = event["event_name"]
        mpi_comm = MPI.COMM_WORLD

        # Get the ASDF filenames.
        processed_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event_name,
            data_type="processed",
            tag_or_iteration=self.comm.waveforms.preprocessing_tag,
        )
        synthetic_filename = self.comm.waveforms.get_asdf_filename(
            event_name=event_name,
            data_type="synthetic",
            tag_or_iteration=iteration,
        )

        for path in (processed_filename, synthetic_filename):
            if not os.path.exists(path):
                msg = "File '%s' does not exists." % path
                raise LASIFNotFoundError(msg)

        # Read all windows on rank 0 and broadcast.
        if mpi_comm.rank == 0:
            all_windows = self.comm.windows.read_all_windows(
                event=event_name, window_set_name=window_set_name)
        else:
            all_windows = {}
        all_windows = mpi_comm.bcast(all_windows, root=0)

        process_params = self.comm.project.simulation_settings

        def process(observed_station, synthetic_station):
            """Return {trace_id: {"misfit", "adj_source"}} for a station."""
            obs_tag = observed_station.get_waveform_tags()
            syn_tag = synthetic_station.get_waveform_tags()

            # Make sure both have length 1.
            assert len(obs_tag) == 1, (
                "Station: %s - Requires 1 observed waveform tag. Has %i." %
                (observed_station._station_name, len(obs_tag)))
            assert len(syn_tag) == 1, (
                "Station: %s - Requires 1 synthetic waveform tag. Has %i." %
                (observed_station._station_name, len(syn_tag)))

            # Finally get the data.
            st_obs = observed_station[obs_tag[0]]
            st_syn = synthetic_station[syn_tag[0]]

            # Process the synthetics.
            st_syn = self.comm.waveforms.process_synthetics(
                st=st_syn.copy(),
                event_name=event_name,
                iteration=iteration,
            )

            adjoint_sources = {}
            misfit_type = self.comm.project.optimization_settings[
                "misfit_type"]
            if misfit_type == "weighted_waveform_misfit":
                env_scaling = True
                ad_src_type = "waveform_misfit"
            else:
                env_scaling = False
                ad_src_type = misfit_type

            for component in ["E", "N", "Z"]:
                try:
                    data_tr = select_component_from_stream(st_obs, component)
                    synth_tr = select_component_from_stream(st_syn, component)
                except LASIFNotFoundError:
                    continue

                if (process_params["scale_data_to_synthetics"]
                        and misfit_type != "L2NormWeighted"):
                    # np.ptp instead of ndarray.ptp (removed in NumPy 2.0).
                    data_ptp = np.ptp(data_tr.data)
                    # A flat data trace would give an infinite factor and
                    # turn the data into inf/NaN, so leave it unscaled.
                    if data_ptp != 0:
                        scaling_factor = np.ptp(synth_tr.data) / data_ptp
                        # Store and apply the scaling.
                        data_tr.stats.scaling_factor = scaling_factor
                        data_tr.data *= scaling_factor

                net, sta, _ = data_tr.id.split(".", 2)
                station_id = f"{net}.{sta}"

                if station_id not in all_windows:
                    continue
                if data_tr.id not in all_windows[station_id]:
                    continue
                windows = all_windows[station_id][data_tr.id]
                try:
                    asrc = calculate_adjoint_source(
                        observed=data_tr,
                        synthetic=synth_tr,
                        window=windows,
                        min_period=process_params["minimum_period_in_s"],
                        max_period=process_params["maximum_period_in_s"],
                        adj_src_type=ad_src_type,
                        window_set=window_set_name,
                        taper_ratio=0.15,
                        taper_type="cosine",
                        plot=plot,
                        envelope_scaling=env_scaling,
                    )
                except Exception:
                    # Either pass or fail for the whole component.
                    continue

                if not asrc:
                    continue
                adjoint_sources[data_tr.id] = {
                    "misfit": asrc.misfit,
                    "adj_source": asrc.adjoint_source.data,
                }

            return adjoint_sources

        # Launch the processing. This will be executed in parallel across
        # ranks. Both data sets are closed again afterwards.
        with pyasdf.ASDFDataSet(processed_filename, mode="r",
                                mpi=False) as ds:
            with pyasdf.ASDFDataSet(synthetic_filename, mode="r",
                                    mpi=False) as ds_synth:
                results = process_two_files_without_parallel_output(
                    ds, ds_synth, process)

        # Write files on all ranks.
        filename = self.get_filename(event=event_name, iteration=iteration)
        long_iter_name = self.comm.iterations.get_long_iteration_name(
            iteration)
        misfit_toml = self.comm.project.paths["iterations"]
        toml_filename = misfit_toml / long_iter_name / "misfits.toml"

        size = mpi_comm.size
        # Rank 0 resets the event misfit once; every rank then adds the
        # misfits of its own stations below.
        if mpi_comm.rank == 0:
            if os.path.exists(toml_filename):
                iteration_misfits = toml.load(toml_filename)
                if event_name in iteration_misfits.keys():
                    iteration_misfits[event_name]["event_misfit"] = 0.0
                with open(toml_filename, "w") as fh:
                    toml.dump(iteration_misfits, fh)
        mpi_comm.Barrier()
        # Ranks write one after another, separated by barriers.
        for thread in range(size):
            rank = mpi_comm.rank
            if rank == thread:
                print(
                    f"Writing adjoint sources for rank: {rank+1} "
                    f"out of {size}",
                    flush=True,
                )
                with pyasdf.ASDFDataSet(filename=filename, mpi=False,
                                        mode="a") as bs:
                    if toml_filename.exists():
                        iteration_misfits = toml.load(toml_filename)
                        if event_name in iteration_misfits.keys():
                            total_misfit = iteration_misfits[event_name][
                                "event_misfit"]
                        else:
                            iteration_misfits[event_name] = {"stations": {}}
                            total_misfit = 0.0
                    else:
                        iteration_misfits = {event_name: {"stations": {}}}
                        total_misfit = 0.0

                    for value in results.values():
                        if not value:
                            continue
                        station_misfit = 0.0
                        for c_id, adj_source in value.items():
                            net, sta, loc, cha = c_id.split(".")

                            bs.add_auxiliary_data(
                                data=adj_source["adj_source"],
                                data_type="AdjointSources",
                                path="%s_%s/Channel_%s_%s" %
                                (net, sta, loc, cha),
                                parameters={"misfit": adj_source["misfit"]},
                            )
                            station_misfit += adj_source["misfit"]
                            station_name = f"{net}.{sta}"
                        iteration_misfits[event_name]["stations"][
                            station_name] = float(station_misfit)
                        total_misfit += station_misfit
                    iteration_misfits[event_name][
                        "event_misfit"] = float(total_misfit)
                    with open(toml_filename, "w") as fh:
                        toml.dump(iteration_misfits, fh)

            mpi_comm.Barrier()
        if mpi_comm.rank == 0:
            with pyasdf.ASDFDataSet(filename=filename, mpi=False,
                                    mode="a") as ds:
                length = len(ds.auxiliary_data.AdjointSources.list())
            print(f"{length} Adjoint sources are in your file.")
Example #24
0
def compare_misfits(
    lasif_root,
    from_it: str,
    to_it: str,
    events: Union[str, List[str]] = None,
    weight_set: str = None,
    print_events: bool = False,
):
    """
    Compares the total misfit between two iterations.

    Total misfit is used regardless of the similarity of the picked windows
    from each iteration. This might skew the results but should
    give a good idea unless the windows change excessively between
    iterations.

    If windows are weighted in the calculation of the adjoint
    sources, that should translate into the calculated misfit
    value.

    The comparison is printed to stdout; nothing is returned. The
    "misfit per event" figures are averaged over the events that were
    actually compared (``events``), not over every event in the project.

    :param lasif_root: path to lasif root directory
    :type lasif_root: Union[str, pathlib.Path, object]
    :param from_it: evaluate misfit from this iteration
    :type from_it: str
    :param to_it: to this iteration
    :type to_it: str
    :param events: An event or a list of events. To get all of them pass
        None, defaults to None
    :type events: Union[str, List[str]], optional
    :param weight_set: Set of station and event weights, defaults to None
    :type weight_set: str, optional
    :param print_events: compare misfits for each event, defaults to False
    :type print_events: bool, optional
    :raises LASIFNotFoundError: if ``weight_set`` or either iteration is
        not known to LASIF.
    """
    comm = find_project_comm(lasif_root)

    if events is None:
        events = comm.events.list()
    if isinstance(events, str):
        events = [events]

    if weight_set and not comm.weights.has_weight_set(weight_set):
        raise LASIFNotFoundError(f"Weights {weight_set} not known "
                                 f"to LASIF")
    # Check if iterations exist
    for iteration in (from_it, to_it):
        if not comm.iterations.has_iteration(iteration):
            raise LASIFNotFoundError(
                f"Iteration {iteration} not known to LASIF")

    from_it_misfit = 0.0
    to_it_misfit = 0.0
    for event in events:
        # Query each misfit only once and reuse it for the per-event report.
        from_it_misfit_event = float(
            comm.adj_sources.get_misfit_for_event(event, from_it, weight_set))
        to_it_misfit_event = float(
            comm.adj_sources.get_misfit_for_event(event, to_it, weight_set))
        from_it_misfit += from_it_misfit_event
        to_it_misfit += to_it_misfit_event

        if print_events:
            # Print information about every event.
            print(f"{event}: \n"
                  f"\t iteration {from_it} has misfit: "
                  f"{from_it_misfit_event} \n"
                  f"\t iteration {to_it} has misfit: {to_it_misfit_event}.")
            if from_it_misfit_event == 0.0:
                # A relative change against a zero baseline is undefined.
                print(f"Relative change: undefined (zero misfit in "
                      f"iteration {from_it})")
            else:
                rel_change = ((to_it_misfit_event - from_it_misfit_event) /
                              from_it_misfit_event * 100.0)
                print(f"Relative change: {rel_change:.2f}%")

    print(f"Total misfit for iteration {from_it}: {from_it_misfit}")
    print(f"Total misfit for iteration {to_it}: {to_it_misfit}")
    if from_it_misfit == 0.0:
        print(f"Relative misfit change from iteration {from_it} to "
              f"iteration {to_it} is undefined: iteration {from_it} has "
              f"zero total misfit.")
    else:
        rel_change = (to_it_misfit - from_it_misfit) / from_it_misfit * 100.0
        if rel_change > 0.0:
            print(f"Misfit has increased {rel_change:.2f}% from iteration "
                  f"{from_it} to iteration {to_it}.")
        else:
            print(f"Misfit has decreased {-rel_change:.2f}% from iteration "
                  f"{from_it} to iteration {to_it}")

    # Average over the events that were summed above, not over every event
    # known to the project (they differ when a subset is requested).
    n_events = len(events)
    if n_events:
        print(f"Misfit per event for iteration {from_it}: "
              f"{from_it_misfit / n_events}")
        print(f"Misfit per event for iteration {to_it}: "
              f"{to_it_misfit / n_events}")
def get_subset_of_events(comm, count, events, existing_events=None):
    """
    Get a spatially well-distributed subset of events. NO QA.

    Selection is greedy. If no ``existing_events`` are given, the first
    event is picked at random. Every further event is the candidate whose
    nearest neighbour (as reported by ``SphericalNearestNeighbour``) among
    the already chosen/existing events is farthest away.

    :param comm: LASIF communicator
    :param count: number of events to choose; must be a positive integer.
    :type count: int
    :param events: list of event_names, from which to choose from. These
        events must be known to LASIF
    :type events: list of str
    :param existing_events: list of events that have been chosen already.
        They are excluded from the selected options, but are also taken
        into account when ensuring a good spatial distribution. They must
        not also appear in ``events``.
    :type existing_events: list of str, optional
    :raises ValueError: if ``count`` is not an integer, is smaller than 1,
        or fewer than ``count`` events could be selected.
    :raises LASIFError: if fewer than ``count`` events are supplied, or an
        event appears in both ``events`` and ``existing_events``.
    :raises LASIFNotFoundError: if an event in ``events`` is not known to
        LASIF.
    :return: a list of chosen event names, in selection order.
    """
    # Validate the type first so a non-integer count yields the documented
    # ValueError rather than a TypeError from the comparison below.
    # bool is an int subclass but is rejected, as before.
    if isinstance(count, bool) or not isinstance(count, int):
        raise ValueError("count should be an integer value.")
    if count < 1:
        raise ValueError("count should be at least 1.")
    if len(events) < count:
        raise LASIFError("Insufficient amount of events specified.")

    available_events = set(comm.events.list())
    for event in events:
        if event not in available_events:
            raise LASIFNotFoundError(f"event : {event} not known to LASIF.")

    if existing_events is None:
        existing_events = []
    else:
        existing_set = set(existing_events)
        for event in events:
            if event in existing_set:
                raise LASIFError(f"event: {event} was existing already, "
                                 f"but still supplied to choose from.")

    # Candidate names and their (latitude, longitude) are kept in two
    # parallel lists that are always popped at the same index.
    candidate_names = []
    coordinates = []
    for event in events:
        event_file_name = comm.waveforms.get_asdf_filename(event,
                                                           data_type="raw")
        with pyasdf.ASDFDataSet(event_file_name, mode="r") as ds:
            ev = ds.events[0]
            org = ev.preferred_origin() or ev.origins[0]
            coordinates.append((org.latitude, org.longitude))
        candidate_names.append(event)

    existing_coordinates = []
    for event in existing_events:
        ev = comm.events.get(event)
        existing_coordinates.append((ev["latitude"], ev["longitude"]))

    chosen_events = []

    # randomly start with one of the specified events
    if not existing_coordinates:
        idx = random.randint(0, len(candidate_names) - 1)
        chosen_events.append(candidate_names.pop(idx))
        existing_coordinates.append(coordinates.pop(idx))

    while len(chosen_events) < count:
        if not coordinates:
            print("\tNo events left to select from. Stopping here.")
            break
        # Build kdtree and query for the point furthest away from any other
        # point.
        kdtree = SphericalNearestNeighbour(np.array(existing_coordinates))
        distances = kdtree.query(np.array(coordinates), k=1)[0]
        idx = int(np.argmax(distances))

        chosen_events.append(candidate_names.pop(idx))
        existing_coordinates.append(coordinates.pop(idx))

    # Compare against the requested count (the original decremented
    # ``count`` in the loop, which made this check ineffective).
    if len(chosen_events) < count:
        raise ValueError("Could not select a sufficient amount of events")

    return chosen_events
Example #26
0
    def select_windows_for_station(
        self,
        event: str,
        iteration: str,
        station: str,
        window_set_name: str,
        **kwargs,
    ):
        """
        Selects windows for the given event, iteration, and station. Will
        delete any previously existing windows for that station if any.

        For each of the E, N and Z components that can be matched in both
        the observed and the synthetic data, the existing windows of that
        channel are deleted. They are replaced by whatever the project's
        ``window_picking_function`` returns, which may be nothing.

        :param event: The event.
        :type event: str
        :param iteration: The iteration name.
        :type iteration: str
        :param station: The station id in the form NET.STA.
        :type station: str
        :param window_set_name: Name of window set
        :type window_set_name: str
        :param kwargs: Passed unchanged to the window picking function.
        :raises LASIFNotFoundError: if no component could be matched in both
            the observed and the synthetic data.
        """
        from lasif.utils import select_component_from_stream

        # Load project specific window selection function.
        select_windows = self.comm.project.get_project_function(
            "window_picking_function")

        event = self.comm.events.get(event)
        event_name = event["event_name"]
        data = self.comm.query.get_matching_waveforms(event_name,
                                                      iteration, station)

        settings = self.comm.project.simulation_settings
        minimum_period = settings["minimum_period_in_s"]
        maximum_period = settings["maximum_period_in_s"]

        # Source time function built from the simulation's time stepping
        # and period range (converted to a frequency band).
        stf_fct = self.comm.project.get_project_function(
            "source_time_function")
        stf_trace = stf_fct(
            npts=settings["number_of_time_steps"],
            delta=settings["time_step_in_s"],
            freqmin=1.0 / maximum_period,
            freqmax=1.0 / minimum_period,
        )

        window_group_manager = self.comm.windows.get(window_set_name)

        found_something = False
        for component in ["E", "N", "Z"]:
            try:
                data_tr = select_component_from_stream(data.data, component)
                synth_tr = select_component_from_stream(
                    data.synthetics, component)
                # delete preexisting windows
                window_group_manager.del_all_windows_from_event_channel(
                    event_name, data_tr.id)
            except LASIFNotFoundError:
                continue
            found_something = True

            windows = select_windows(
                data_tr,
                synth_tr,
                stf_trace,
                event["latitude"],
                event["longitude"],
                event["depth_in_km"],
                data.coordinates["latitude"],
                data.coordinates["longitude"],
                minimum_period=minimum_period,
                maximum_period=maximum_period,
                iteration=iteration,
                **kwargs,
            )
            if not windows:
                continue

            # Each window is unpacked as a 3-tuple; only the start and end
            # times are stored.
            for starttime, endtime, _ in windows:
                window_group_manager.add_window_to_event_channel(
                    event_name=event_name,
                    channel_name=data_tr.id,
                    start_time=starttime,
                    end_time=endtime,
                )

        if not found_something:
            # ``iteration`` is the iteration name (a str); it has no
            # ``.name`` attribute.
            raise LASIFNotFoundError(
                "No matching data found for event '%s', iteration '%s', and "
                "station '%s'." %
                (event_name, iteration, station))