Example #1
0
    def test_loading_unregistered_function_registers(self):
        """ ensure that a function placed in the cache without having been
        decorated gets decorated (registered) when it is returned """
        def func(templates, streams, pads):
            pass

        corr.XCOR_FUNCS['func_test'] = func
        try:
            corr.get_stream_xcorr('func_test')
            assert hasattr(corr.XCOR_FUNCS['func_test'], 'registered')
        finally:
            # Remove the throw-away entry again so it does not leak into the
            # module-level registry that other tests share.
            corr.XCOR_FUNCS.pop('func_test', None)
Example #2
0
 def test_str_accepted(self):
     """ Check that the name (a str) of an xcorr function can be given to
     set_xcorr, and that both defaults are restored on leaving the context.
     """
     array_default_before = corr.get_array_xcorr()
     stream_default_before = corr.get_stream_xcorr()
     with corr.set_xcorr('numpy'):
         assert corr.get_array_xcorr() is corr.numpy_normxcorr
     # Leaving the context manager must put the previous defaults back.
     assert corr.get_array_xcorr() == array_default_before
     assert corr.get_stream_xcorr() == stream_default_before
Example #3
0
    def test_callable_registered(self, multichannel_templates,
                                 multichannel_stream):
        """ Check that an arbitrary callable can be handed to
        get_stream_xcorr and is actually invoked by the returned function. """
        calls = {}

        def some_callable(template_array, stream_array, pad_array):
            # Record the call, then defer the real work to the numpy backend.
            calls['name'] = 1
            return corr.numpy_normxcorr(
                template_array, stream_array, pad_array)

        stream_xcorr = corr.get_stream_xcorr(some_callable)
        stream_xcorr(multichannel_templates, multichannel_stream)
        assert 'name' in calls
Example #4
0
 def test_correlation_precision(self):
     """Check that correlating the concatenated detection streams gives the
     same per-channel output as correlating each stream separately with the
     fftw stream xcorr function."""
     ccc, _ = _concatenate_and_correlate(
         streams=self.detect_streams, template=self.template, cores=1)
     fftw_xcorr = get_stream_xcorr("fftw")
     for stream_ccc, detect_stream in zip(ccc, self.detect_streams):
         reference_ccc, _, _ = fftw_xcorr(
             templates=[self.template], stream=detect_stream, stack=False)
         # A single template was passed, so compare against its channels.
         paired = zip(stream_ccc, reference_ccc[0])
         for chan_ccc, reference_chan_ccc in paired:
             self.assertTrue(
                 np.allclose(chan_ccc, reference_chan_ccc, atol=.00001))
Example #5
0
def _concatenate_and_correlate(streams, template, cores):
    """
    Concatenate a list of streams into one stream and correlate that with a
    template.

    For every channel present in both the streams and the template, the data
    of each stream are written end-to-end into one long trace. Every stream
    occupies a slot as long as the longest trace found in any stream: shorter
    traces are zero-padded, and a channel missing from a stream is left as
    zeros. The concatenated stream is correlated once with the template using
    the "fftw" stream xcorr function. The result is then cut back into one
    block per input stream.

    All traces in all streams must share one sampling rate, and all traces in
    the template must have the same length.

    :type streams: list
    :param streams: List of obspy Streams to concatenate.
    :type template: obspy.core.stream.Stream
    :param template: Template to correlate with the concatenated data.
    :type cores: int
    :param cores: Number of cores passed on to the correlation function.

    :returns:
        Array of shape
        (len(streams), n_channels, longest_trace - template_length + 1)
        holding the correlations (zeros where a stream lacks a channel). Also
        returns, for each stream, a list of ``UsedChannel(channel, used)``
        named tuples, where ``channel`` is a (station, channel) tuple. These
        are in the same channel order as the array.
    """
    UsedChannel = namedtuple("UsedChannel", "channel used")

    samp_rate = {tr.stats.sampling_rate for st in streams for tr in st}
    assert len(samp_rate) == 1, "Multiple sample rates found"
    samp_rate = samp_rate.pop()

    template_length = {tr.stats.npts for tr in template}
    assert len(template_length) == 1, "Multiple data lengths in template"
    template_length = template_length.pop()

    channel_lengths = {tr.stats.npts for st in streams for tr in st}
    if len(channel_lengths) > 1:
        Logger.debug("Multiple lengths of stream found, using the longest")
    channel_length = max(channel_lengths)

    # Only channels present in both the data and the template are useful.
    # Sort them so that the channel order (and hence the output) does not
    # depend on the hash-randomised iteration order of a set of strings.
    chan_set = {tr.id for st in streams for tr in st}.intersection(
        {tr.id for tr in template})
    chans = sorted(chan_set)
    # pre-define data for efficiency: one row per channel, one slot of
    # channel_length samples per stream.
    data = np.zeros((len(chans), channel_length * len(streams)),
                    dtype=np.float32)

    # concatenate detection streams together.
    used_chans = [[] for _ in streams]
    concatenated_stream = Stream()
    for i, seed_id in enumerate(chans):
        net, sta, loc, cha = seed_id.split('.')
        sta_cha = (sta, cha)
        for j, stream in enumerate(streams):
            tr = stream.select(id=seed_id)
            if len(tr) == 0:
                # No data for this channel in this stream: leave zeros.
                used_chans[j].append(
                    UsedChannel(channel=sta_cha, used=False))
                continue
            assert len(tr) == 1, "Multiple channels found for {0}".format(
                seed_id)
            start_index = j * channel_length
            data[i][start_index:start_index + tr[0].stats.npts] = tr[0].data
            used_chans[j].append(UsedChannel(channel=sta_cha, used=True))
        concatenated_stream += Trace(
            data=data[i], header=dict(network=net, station=sta, channel=cha,
                                      location=loc, sampling_rate=samp_rate))

    # Remove unnecessary channels from template
    _template = Stream()
    for tr in template:
        if tr.id in chan_set:
            _template += tr

    # Do correlations
    xcorr_func = get_stream_xcorr(name_or_func="fftw")
    ccc, _, chan_order = xcorr_func(
        templates=[_template], stream=concatenated_stream, stack=False,
        cores=cores)
    # Re-order used_chans to match the channel order of the correlations.
    chan_order = chan_order[0]
    for _used_chans in used_chans:
        _used_chans.sort(key=lambda used: chan_order.index(used.channel))

    # Cut the correlations of the concatenated data back into one block per
    # input stream; channels without data in a stream stay zero.
    n_corr = channel_length - template_length + 1
    ccc_out = np.zeros((len(streams), len(chans), n_corr), dtype=np.float32)
    for i, stream_chans in enumerate(used_chans):
        index_start = i * channel_length
        index_end = index_start + n_corr
        for j, used_chan in enumerate(stream_chans):
            if used_chan.used:
                ccc_out[i][j] = ccc[0][j][index_start:index_end]
    return ccc_out, used_chans
Example #6
0
def match_filter(template_names,
                 template_list,
                 st,
                 threshold,
                 threshold_type,
                 trig_int,
                 plot=False,
                 plotdir=None,
                 xcorr_func=None,
                 concurrency=None,
                 cores=None,
                 plot_format='png',
                 output_cat=False,
                 output_event=True,
                 extract_detections=False,
                 arg_check=True,
                 full_peaks=False,
                 peak_cores=None,
                 spike_test=True,
                 **kwargs):
    r"""
    Main matched-filter detection function.

    Over-arching code to run the correlations of given templates with a
    day of seismic data and output the detections based on a given threshold.
    For a functional example see the tutorials.

    :type template_names: list
    :param template_names:
        List of template names in the same order as template_list
    :type template_list: list
    :param template_list:
        A list of templates of which each template is a
        :class:`obspy.core.stream.Stream` of obspy traces containing seismic
        data and header information.
    :type st: :class:`obspy.core.stream.Stream`
    :param st:
        A Stream object containing all the data available and
        required for the correlations with templates given.  For efficiency
        this should contain no excess traces which are not in one or more of
        the templates.  This will now remove excess traces internally, but
        will copy the stream and work on the copy, leaving your input stream
        untouched.
    :type threshold: float
    :param threshold: A threshold value set based on the threshold_type
    :type threshold_type: str
    :param threshold_type:
        The type of threshold to be used, can be MAD, absolute or av_chan_corr.
        See Note on thresholding below.
    :type trig_int: float
    :param trig_int:
        Minimum gap between detections from one template in seconds.
        If multiple detections occur within trig_int of one-another, the one
        with the highest cross-correlation sum will be selected.
    :type plot: bool
    :param plot: Turn plotting on or off
    :type plotdir: str
    :param plotdir:
        Path to plotting folder, plots will be output here, defaults to None,
        and plots are shown on screen.
    :type xcorr_func: str or callable
    :param xcorr_func:
        A str of a registered xcorr function or a callable for implementing
        a custom xcorr function. For more information see:
        :func:`eqcorrscan.utils.correlate.register_array_xcorr`
    :type concurrency: str
    :param concurrency:
        The type of concurrency to apply to the xcorr function. Options are
        'multithread', 'multiprocess', 'concurrent'. For more details see
        :func:`eqcorrscan.utils.correlate.get_stream_xcorr`
    :type cores: int
    :param cores: Number of cores to use
    :type plot_format: str
    :param plot_format: Specify format of output plots if saved
    :type output_cat: bool
    :param output_cat:
        Specifies if matched_filter will output an obspy.Catalog class
        containing events for each detection. Default is False, in which case
        matched_filter will output a list of detection classes, as normal.
    :type output_event: bool
    :param output_event:
        Whether to include events in the Detection objects, defaults to True,
        but for large cases you may want to turn this off as Event objects
        can be quite memory intensive.
    :type extract_detections: bool
    :param extract_detections:
        Specifies whether or not to return a list of streams, one stream per
        detection.
    :type arg_check: bool
    :param arg_check:
        Check arguments, defaults to True, but if running in bulk, and you are
        certain of your arguments, then set to False.
    :type full_peaks: bool
    :param full_peaks: See
        :func:`eqcorrscan.utils.findpeaks.find_peaks_compiled`
    :type peak_cores: int
    :param peak_cores:
        Number of processes to use for parallel peak-finding (if different to
        `cores`).
    :type spike_test: bool
    :param spike_test: If set True, raise error when there is a spike in data.
        defaults to True.
    :param kwargs:
        Passed on to the correlation function. ``plotvar`` is still accepted
        as a deprecated alias for ``plot``; it is consumed here and is not
        forwarded to the correlation function.

    .. Note::
        When using the "fftw" correlation backend the length of the fft
        can be set. See :mod:`eqcorrscan.utils.correlate` for more info.

    .. note::
        **Returns:**

        If neither `output_cat` or `extract_detections` are set to `True`,
        then only the list of :class:`eqcorrscan.core.match_filter.Detection`'s
        will be output:

        :return:
            :class:`eqcorrscan.core.match_filter.Detection` detections for each
            detection made.
        :rtype: list

        If `output_cat` is set to `True`, then the
        :class:`obspy.core.event.Catalog` will also be output:

        :return: Catalog containing events for each detection, see above.
        :rtype: :class:`obspy.core.event.Catalog`

        If `extract_detections` is set to `True` then the list of
        :class:`obspy.core.stream.Stream`'s will also be output.

        :return:
            list of :class:`obspy.core.stream.Stream`'s for each detection, see
            above.
        :rtype: list

    .. note::
        If your data contain gaps these must be padded with zeros before
        using this function. The `eqcorrscan.utils.pre_processing` functions
        will provide gap-filled data in the appropriate format.  Note that if
        you pad your data with zeros before filtering or resampling the gaps
        will not be all zeros after filtering. This will result in the
        calculation of spurious correlations in the gaps.

    .. Note::
        Detections are not corrected for `pre-pick`, the
        detection.detect_time corresponds to the beginning of the earliest
        template channel at detection.

    .. note::
        **Data overlap:**

        Internally this routine shifts and trims the data according to the
        offsets in the template (e.g. if trace 2 starts 2 seconds after trace 1
        in the template then the continuous data will be shifted by 2 seconds
        to align peak correlations prior to summing).  Because of this,
        detections at the start and end of continuous data streams
        **may be missed**.  The maximum time-period that might be missing
        detections is the maximum offset in the template.

        To work around this, if you are conducting matched-filter detections
        through long-duration continuous data, we suggest using some overlap
        (a few seconds, on the order of the maximum offset in the templates)
        in the continuous data.  You will then need to post-process the
        detections (which should be done anyway to remove duplicates).

    .. note::
        **Thresholding:**

        **MAD** threshold is calculated as the:

        .. math::

            threshold {\times} (median(abs(cccsum)))

        where :math:`cccsum` is the cross-correlation sum for a given template.

        **absolute** threshold is a true absolute threshold based on the
        cccsum value.

        **av_chan_corr** is based on the mean values of single-channel
        cross-correlations assuming all data are present as required for the
        template, e.g:

        .. math::

            av\_chan\_corr\_thresh=threshold \times (cccsum\ /\ len(template))

        where :math:`template` is a single template from the input and the
        length is the number of channels within this template.

    .. note::
        The output_cat flag will create an :class:`obspy.core.event.Catalog`
        containing one event for each
        :class:`eqcorrscan.core.match_filter.Detection`'s generated by
        match_filter. Each event will contain a number of comments dealing
        with correlation values and channels used for the detection. Each
        channel used for the detection will have a corresponding
        :class:`obspy.core.event.Pick` which will contain time and
        waveform information. **HOWEVER**, the user should note that
        the pick times do not account for the prepick times inherent in
        each template. For example, if a template trace starts 0.1 seconds
        before the actual arrival of that phase, then the pick time generated
        by match_filter for that phase will be 0.1 seconds early.

    .. Note::
        xcorr_func can be used as follows:

        .. rubric::xcorr_func argument example

        >>> import obspy
        >>> import numpy as np
        >>> from eqcorrscan.core.match_filter.matched_filter import (
        ...    match_filter)
        >>> from eqcorrscan.utils.correlate import time_multi_normxcorr
        >>> # define a custom xcorr function
        >>> def custom_normxcorr(templates, stream, pads, *args, **kwargs):
        ...     # Just to keep example short call other xcorr function
        ...     # in practice you would define your own function here
        ...     print('calling custom xcorr function')
        ...     return time_multi_normxcorr(templates, stream, pads)
        >>> # generate some toy templates and stream
        >>> random = np.random.RandomState(42)
        >>> template = obspy.read()
        >>> stream = obspy.read()
        >>> for num, tr in enumerate(stream):  # iter st and embed templates
        ...     data = tr.data
        ...     tr.data = random.randn(6000) * 5
        ...     tr.data[100: 100 + len(data)] = data
        >>> # call match_filter and ensure the custom function is used
        >>> detections = match_filter(
        ...     template_names=['1'], template_list=[template], st=stream,
        ...     threshold=.5, threshold_type='absolute', trig_int=1,
        ...     plotvar=False,
        ...     xcorr_func=custom_normxcorr)  # doctest:+ELLIPSIS
        calling custom xcorr function...
    """
    from eqcorrscan.core.match_filter.detection import Detection
    from eqcorrscan.utils.plotting import _match_filter_plot

    if "plotvar" in kwargs:
        Logger.warning("plotvar is depreciated, use plot instead")
        # pop rather than get: kwargs is forwarded to the correlation
        # function below and must not carry this plotting flag along.
        plot = kwargs.pop("plotvar")

    if arg_check:
        # Check the arguments to be nice - if arguments wrong type the parallel
        # output for the error won't be useful
        if not isinstance(template_names, list):
            raise MatchFilterError('template_names must be of type: list')
        if not isinstance(template_list, list):
            raise MatchFilterError('templates must be of type: list')
        if not len(template_list) == len(template_names):
            raise MatchFilterError('Not the same number of templates as names')
        for template in template_list:
            if not isinstance(template, Stream):
                msg = 'template in template_list must be of type: ' + \
                      'obspy.core.stream.Stream'
                raise MatchFilterError(msg)
        if not isinstance(st, Stream):
            msg = 'st must be of type: obspy.core.stream.Stream'
            raise MatchFilterError(msg)
        if str(threshold_type) not in [
                str('MAD'), str('absolute'),
                str('av_chan_corr')
        ]:
            msg = 'threshold_type must be one of: MAD, absolute, av_chan_corr'
            raise MatchFilterError(msg)
        for tr in st:
            if not tr.stats.sampling_rate == st[0].stats.sampling_rate:
                raise MatchFilterError(
                    'Sampling rates are not equal %f: %f' %
                    (tr.stats.sampling_rate, st[0].stats.sampling_rate))
        for template in template_list:
            for tr in template:
                if not tr.stats.sampling_rate == st[0].stats.sampling_rate:
                    raise MatchFilterError('Template sampling rate does not '
                                           'match continuous data')
        for template in template_list:
            for tr in template:
                if isinstance(tr.data, np.ma.core.MaskedArray):
                    raise MatchFilterError(
                        'Template contains masked array, split first')
    if spike_test:
        Logger.info("Checking for spikes in data")
        _spike_test(st)
    # Peak-finding runs in parallel whenever a core count was given.
    parallel = cores is not None
    if peak_cores is None:
        peak_cores = cores
    # Copy the stream here because we will muck about with it
    Logger.info("Copying data to keep your input safe")
    stream = st.copy()
    templates = [t.copy() for t in template_list]
    # Shallow copy; list() also copes with non-list sequences when
    # arg_check is False.
    _template_names = list(template_names)

    Logger.info("Reshaping templates")
    stream, templates, _template_names = _prep_data_for_correlation(
        stream=stream, templates=templates, template_names=_template_names)
    if len(templates) == 0:
        raise IndexError("No matching data")
    Logger.info('Starting the correlation run for these data')
    for template in templates:
        Logger.debug(template.__str__())
    Logger.debug(stream.__str__())
    multichannel_normxcorr = get_stream_xcorr(xcorr_func, concurrency)
    outtic = default_timer()
    [cccsums, no_chans, chans] = multichannel_normxcorr(templates=templates,
                                                        stream=stream,
                                                        cores=cores,
                                                        **kwargs)
    if len(cccsums[0]) == 0:
        raise MatchFilterError('Correlation has not run, zero length cccsum')
    outtoc = default_timer()
    Logger.info(
        'Looping over templates and streams took: {0:.4f}s'.format(outtoc -
                                                                   outtic))
    Logger.debug('The shape of the returned cccsums is: {0}'.format(
        cccsums.shape))
    Logger.debug('This is from {0} templates correlated with {1} channels of '
                 'data'.format(len(templates), len(stream)))
    detections = []
    if output_cat:
        det_cat = Catalog()
    if str(threshold_type) == str("absolute"):
        thresholds = [threshold for _ in range(len(cccsums))]
    elif str(threshold_type) == str('MAD'):
        thresholds = [
            threshold * np.median(np.abs(cccsum)) for cccsum in cccsums
        ]
    else:
        # av_chan_corr: scale by the number of channels used per template.
        thresholds = [threshold * no_chans[i] for i in range(len(cccsums))]
    outtic = default_timer()
    all_peaks = multi_find_peaks(arr=cccsums,
                                 thresh=thresholds,
                                 parallel=parallel,
                                 trig_int=int(trig_int *
                                              stream[0].stats.sampling_rate),
                                 full_peaks=full_peaks,
                                 cores=peak_cores)
    outtoc = default_timer()
    Logger.info("Finding peaks took {0:.4f}s".format(outtoc - outtic))
    for i, cccsum in enumerate(cccsums):
        if np.abs(np.mean(cccsum)) > 0.05:
            Logger.warning('Mean is not zero!  Check this!')
        # Set up a trace object for the cccsum as this is easier to plot and
        # maintains timing
        if plot:
            _match_filter_plot(stream=stream,
                               cccsum=cccsum,
                               template_names=_template_names,
                               rawthresh=thresholds[i],
                               plotdir=plotdir,
                               plot_format=plot_format,
                               i=i)
        if all_peaks[i]:
            Logger.debug("Found {0} peaks for template {1}".format(
                len(all_peaks[i]), _template_names[i]))
            for peak in all_peaks[i]:
                # peak is (correlation value, sample index into cccsum)
                detecttime = (stream[0].stats.starttime +
                              peak[1] / stream[0].stats.sampling_rate)
                detection = Detection(template_name=_template_names[i],
                                      detect_time=detecttime,
                                      no_chans=no_chans[i],
                                      detect_val=peak[0],
                                      threshold=thresholds[i],
                                      typeofdet='corr',
                                      chans=chans[i],
                                      threshold_type=threshold_type,
                                      threshold_input=threshold)
                if output_cat or output_event:
                    detection._calculate_event(template_st=templates[i])
                detections.append(detection)
                if output_cat:
                    det_cat.append(detection.event)
        else:
            Logger.debug("Found 0 peaks for template {0}".format(
                _template_names[i]))
    Logger.info("Made {0} detections from {1} templates".format(
        len(detections), len(templates)))
    if extract_detections:
        detection_streams = extract_from_stream(stream, detections)
    del stream, templates

    if output_cat and extract_detections:
        return detections, det_cat, detection_streams
    if output_cat:
        return detections, det_cat
    if extract_detections:
        return detections, detection_streams
    return detections
Example #7
0
 def test_bad_concurrency_raises(self):
     """ Asking get_stream_xcorr for an unknown concurrency mode must raise
     a ValueError. """
     bogus_concurrency = 'node.js'
     with pytest.raises(ValueError):
         corr.get_stream_xcorr(concurrency=bogus_concurrency)
Example #8
0
 def test_noargs_returns_default(self):
     """ Calling get_stream_xcorr with no arguments must hand back the
     stream xcorr function stored under the 'default' key. """
     returned = corr.get_stream_xcorr()
     assert returned is corr.XCOR_FUNCS['default'].stream_xcorr
Example #9
0

@pytest.fixture(scope='module')
def multichannel_stream():
    """ Module-scoped fixture providing a multichannel stream for tests. """
    stream = generate_multichannel_stream()
    return stream


@pytest.fixture(scope='module')
def gappy_multichannel_stream():
    """ Module-scoped fixture providing a multichannel stream containing
    gaps that have been padded with zeros. """
    gappy_stream = generate_gappy_multichannel_stream()
    return gappy_stream


# Map "<xcorr name>_<stream method>" -> stream xcorr function, covering every
# originally registered xcorr function (except 'default') combined with every
# stream method.
stream_funcs = {
    fname + '_' + mname: corr.get_stream_xcorr(fname, mname)
    for fname in corr.XCORR_FUNCS_ORIGINAL
    if fname != 'default'
    for mname in corr.XCORR_STREAM_METHODS
}


@pytest.fixture(scope='module')
def stream_cc_output_dict(multichannel_templates, multichannel_stream):
    """ Module-scoped fixture: run every entry of ``stream_funcs`` (through
    ``time_func``, on one core) on the multichannel templates and stream,
    and return the outputs keyed by the same names. """
    return {
        name: time_func(func, name, multichannel_templates,
                        multichannel_stream, cores=1)
        for name, func in stream_funcs.items()
    }
Example #10
0
def cross_chan_correlation(st1,
                           streams,
                           shift_len=0.0,
                           allow_individual_trace_shifts=True,
                           xcorr_func='fftw',
                           concurrency="concurrent",
                           cores=1,
                           **kwargs):
    """
    Calculate cross-channel correlation.

    Determine the cross-channel correlation between two streams of
    multichannel seismic data.

    :type st1: obspy.core.stream.Stream
    :param st1: Stream one
    :type streams: list
    :param streams: Streams to compare to.
    :type shift_len: float
    :param shift_len: How many seconds for templates to shift
    :type allow_individual_trace_shifts: bool
    :param allow_individual_trace_shifts:
        Controls whether templates are shifted by shift_len in relation to the
        picks as a whole, or whether each trace can be shifted individually.
        Defaults to True. Has no effect when shift_len is not positive.
    :type xcorr_func: str, callable
    :param xcorr_func:
        The method for performing correlations. Accepts either a string or
        callable. See :func:`eqcorrscan.utils.correlate.register_array_xcorr`
        for more details
    :type concurrency: str
    :param concurrency: Concurrency for xcorr-func.
    :type cores: int
    :param cores: Number of threads to parallel over

    :returns:
        cross channel correlation, float - normalized by number of channels.
        locations of maximums, in seconds relative to the centre of the
        allowed shift window. With individual trace shifts there is one
        column per trace (nan for traces masked in st1); otherwise there is
        a single column.
    :rtype: numpy.ndarray, numpy.ndarray

    .. Note::
        If no matching channels were found then the coherance and index for
        that stream will be nan.
    """
    # Individual trace shifts only make sense when shifting is allowed.
    allow_individual_trace_shifts = (allow_individual_trace_shifts
                                     and shift_len > 0)
    n_streams = len(streams)
    df = st1[0].stats.sampling_rate
    end_trim = int((shift_len * df) / 2)
    # _prep_data_for_correlation works in place on data, so always work on
    # copies of the user's streams. When shifting, cut end_trim samples off
    # both ends of every trace so it can slide relative to st1.
    if end_trim > 0:
        _streams = []
        for stream in streams:
            _stream = stream.copy()  # Do not work on the users data
            for tr in _stream:
                tr.data = tr.data[end_trim:-end_trim]
                if tr.stats.sampling_rate != df:
                    raise NotImplementedError("Sampling rates differ")
            _streams.append(_stream)
        streams = _streams
    else:
        streams = [stream.copy() for stream in streams]
    # Check which channels are in st1 and match those in the stream_list
    st1_preped, prep_streams, stream_indexes = _prep_data_for_correlation(
        stream=st1.copy(),
        templates=streams,
        template_names=list(range(len(streams))),
        force_stream_epoch=False)
    # Run the correlations
    multichannel_normxcorr = get_stream_xcorr(xcorr_func, concurrency)
    [cccsums, no_chans, _] = multichannel_normxcorr(templates=prep_streams,
                                                    stream=st1_preped,
                                                    cores=cores,
                                                    stack=False,
                                                    **kwargs)
    # Find maxima, sum and divide by no_chans
    if allow_individual_trace_shifts:
        # Each channel contributes its own maximum.
        coherances = cccsums.max(axis=-1).sum(axis=-1) / no_chans
    else:
        # Stack channels first, then take a single maximum per stream.
        cccsums = cccsums.sum(axis=1)
        coherances = cccsums.max(axis=-1) / no_chans
    # Subtract half length of correlogram and convert positions to seconds
    positions = (cccsums.argmax(axis=-1) - end_trim) / df

    if allow_individual_trace_shifts:
        n_max_traces = max(len(st) for st in prep_streams)
        # Set shifts for nan-traces to nan
        for i, tr in enumerate(st1_preped):
            if np.ma.is_masked(tr.data):
                positions[:, i] = np.nan
    else:
        positions = positions[:, np.newaxis]
        n_max_traces = 1
    n_shifts_per_stream = positions.shape[1]

    # Re-order the results to correspond to the order of the input streams;
    # streams without matching channels keep nan.
    _coherances = np.full(n_streams, np.nan)
    _positions = np.full((n_streams, n_max_traces), np.nan)
    _coherances[np.ix_(stream_indexes)] = coherances
    _positions[np.ix_(stream_indexes,
                      range(n_shifts_per_stream))] = positions
    return _coherances, _positions
Example #11
0
def cross_chan_correlation(st1, streams, shift_len=0.0, xcorr_func='fftw',
                           concurrency="concurrent", cores=1, **kwargs):
    """
    Calculate cross-channel correlation.

    Determine the cross-channel correlation between two streams of
    multichannel seismic data. The input streams are not modified.

    :type st1: obspy.core.stream.Stream
    :param st1: Stream one
    :type streams: list
    :param streams: Streams to compare to.
    :type shift_len: float
    :param shift_len:
        Seconds to shift: every trace in ``streams`` has half of this length
        trimmed from each end before correlating against st1.
    :type xcorr_func: str, callable
    :param xcorr_func:
        The method for performing correlations. Accepts either a string or
        callable. See :func:`eqcorrscan.utils.correlate.register_array_xcorr`
        for more details
    :type concurrency: str
    :param concurrency: Concurrency for xcorr-func.
    :type cores: int
    :param cores: Number of threads to parallel over

    :returns:
        cross channel correlation, float - normalized by number of channels.
        locations of maximums, one per channel, as raw sample indices into
        the correlogram (not centred and not converted to seconds).
    :rtype: numpy.ndarray, numpy.ndarray

    .. Note::
        If no matching channels were found then the coherance and index for
        that stream will be nan.
    """
    n_streams = len(streams)
    df = st1[0].stats.sampling_rate
    end_trim = int((shift_len * df) / 2)
    # _prep_data_for_correlation works in place on data, so only ever hand
    # it copies of the user's streams. Cut all channels in the stream-list
    # to be shorter than st1 by shift_len.
    if end_trim > 0:
        _streams = []
        for stream in streams:
            _stream = stream.copy()  # Do not work on the users data
            for tr in _stream:
                tr.data = tr.data[end_trim: -end_trim]
                if tr.stats.sampling_rate != df:
                    raise NotImplementedError("Sampling rates differ")
            _streams.append(_stream)
        streams = _streams
    else:
        streams = [stream.copy() for stream in streams]
    # Check which channels are in st1 and match those in the stream_list
    st1_preped, prep_streams, stream_indexes = _prep_data_for_correlation(
        stream=st1.copy(), templates=streams,
        template_names=list(range(len(streams))), force_stream_epoch=False)
    # Run the correlations
    multichannel_normxcorr = get_stream_xcorr(xcorr_func, concurrency)
    [cccsums, no_chans, _] = multichannel_normxcorr(
        templates=prep_streams, stream=st1_preped, cores=cores, stack=False,
        **kwargs)
    # Find maxima, sum and divide by no_chans
    coherances = cccsums.max(axis=-1).sum(axis=-1) / no_chans
    positions = cccsums.argmax(axis=-1)
    # Size the output from the correlogram itself rather than from
    # no_chans.max(), which can differ from the number of correlated channels
    # (and fails on empty input).
    _coherances = np.full(n_streams, np.nan)
    _positions = np.full((n_streams, positions.shape[-1]), np.nan)
    for coh_ind, stream_ind in enumerate(stream_indexes):
        _coherances[stream_ind] = coherances[coh_ind]
        _positions[stream_ind] = positions[coh_ind]
    return _coherances, _positions
def test_multi_channel_xcorr():
    """
    Check that the time-domain and fftw stream xcorr functions agree with
    each other across their concurrency options, printing how long each
    combination takes.
    """
    channels = ['EHZ', 'EHN', 'EHE']
    stations = ['COVA', 'FOZ', 'LARB', 'GOVA', 'MTFO', 'MTBA']
    n_templates = 20
    stream_len = 10000
    template_len = 200
    # Seeded generator so that any failure is reproducible.
    random = np.random.RandomState(42)

    def _noise_stream(npts):
        """ White-noise stream with one trace per station/channel pair. """
        noise = Stream()
        for station in stations:
            for channel in channels:
                noise += Trace(data=random.randn(npts),
                               header=dict(station=station, channel=channel))
        return noise

    stream = _noise_stream(stream_len)
    templates = [_noise_stream(template_len) for _ in range(n_templates)]

    def _timed_xcorr(label, name, concurrency, **kwargs):
        """ Correlate templates with stream and print the elapsed time. """
        tic = time.perf_counter()
        xcorr = get_stream_xcorr(name, concurrency=concurrency)
        cccsums, _, _ = xcorr(templates=templates, stream=stream, **kwargs)
        print('%s took: %f seconds' % (label, time.perf_counter() - tic))
        return cccsums

    print("Running time serial")
    cccsums_t_s = _timed_xcorr("Time-domain in serial", "time_domain", None)
    print("Running time parallel")
    cccsums_t_p = _timed_xcorr("Time-domain in parallel", "time_domain",
                               "multiprocess", cores=4)
    print("Running frequency serial")
    cccsums_f_s = _timed_xcorr("Frequency-domain in serial", "fftw", None)
    print("Running frequency parallel")
    cccsums_f_p = _timed_xcorr("Frequency-domain in parallel", "fftw",
                               "multiprocess", cores=4)
    print("Running frequency openmp parallel")
    cccsums_f_op = _timed_xcorr("Frequency-domain in openmp parallel",
                                "fftw", "concurrent")
    print("Finished")
    assert (np.allclose(cccsums_t_s, cccsums_t_p, atol=0.00001))
    assert (np.allclose(cccsums_f_s, cccsums_f_p, atol=0.00001))
    assert (np.allclose(cccsums_f_s, cccsums_f_op, atol=0.00001))
    assert (np.allclose(cccsums_t_p, cccsums_f_s, atol=0.001))