Example 1
def test_core():
    recs_per_chunk = 10
    n_chunks = 10

    class Records(strax.Plugin):
        provides = 'records'
        depends_on = tuple()
        dtype = strax.record_dtype()

        def iter(self, *args, **kwargs):
            for t in range(n_chunks):
                r = np.zeros(recs_per_chunk, self.dtype)
                r['time'] = t
                r['length'] = 1
                r['dt'] = 1
                r['channel'] = np.arange(len(r))
                yield r

    class Peaks(strax.Plugin):
        provides = 'peaks'
        depends_on = ('records', )
        dtype = strax.peak_dtype()

        def compute(self, records):
            p = np.zeros(len(records), self.dtype)
            p['time'] = records['time']
            return p

    mystrax = strax.Context(storage=[])
    mystrax.register(Records)
    mystrax.register(Peaks)

    bla = mystrax.get_array(run_id='some_run', targets='peaks')
    assert len(bla) == recs_per_chunk * n_chunks
    assert bla.dtype == strax.peak_dtype()
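For orientation, here is a minimal sketch (an assumption-laden illustration, not part of the test above) that inspects what strax.peak_dtype() returns. As later examples in this listing show, the return value behaves like a list of numpy field descriptors, so it is wrapped in np.zeros to obtain the structured dtype; the keyword arguments follow the calls used elsewhere in this listing.

import numpy as np
import strax

# peak_dtype() behaves like a list of numpy field descriptors (later examples
# append extra fields to it with '+'), so build an array to get at the dtype.
fields = strax.peak_dtype(n_channels=100, n_sum_wv_samples=200)
peaks = np.zeros(1, dtype=fields)
print(peaks.dtype.names)     # field names such as 'time', 'length', 'dt', 'channel', 'area', 'data'
print(peaks['data'].shape)   # (1, n_sum_wv_samples): one sum-waveform buffer per peak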
Example 2
class Peaks(strax.Plugin):
    """
    Stolen from straxen, extended marginally
    """
    depends_on = ('records',)
    data_kind = 'peaks'
    parallel = True
    rechunk_on_save = True
    dtype = strax.peak_dtype(n_channels=len(to_pe))

    def compute(self, records):
        r = records
        hits = strax.find_hits(r)       # TODO: Duplicate work
        hits = strax.sort_by_time(hits)

        peaks = strax.find_peaks(hits, to_pe,
                                 result_dtype=self.dtype)
        strax.sum_waveform(peaks, r, to_pe)

        peaks = strax.split_peaks(peaks, r, to_pe)

        strax.compute_widths(peaks)

        if self.config['diagnose_sorting']:
            assert np.diff(r['time']).min() >= 0, "Records not sorted"
            assert np.diff(hits['time']).min() >= 0, "Hits not sorted"
            assert np.all(peaks['time'][1:]
                          >= strax.endtime(peaks)[:-1]), "Peaks not disjoint"

        return peaks
Example 3
def test_sum_waveform(records, peak_left, peak_length):
    # Make a single big peak to contain all the records
    n_ch = 100
    peaks = np.zeros(1, strax.peak_dtype(n_ch, n_sum_wv_samples=200))
    p = peaks[0]
    p['time'] = peak_left
    p['length'] = peak_length
    p['dt'] = 0

    strax.sum_waveform(peaks, records, np.ones(n_ch))

    # Area measures must be consistent
    area = p['area']
    assert area >= 0
    assert p['data'].sum() == area
    assert p['area_per_channel'].sum() == area

    # Create a simple sum waveform
    if not len(records):
        max_sample = 3   # Whatever
    else:
        max_sample = (records['time'] + records['length']).max()
    max_sample = max(max_sample, peak_left + peak_length)
    sum_wv = np.zeros(max_sample + 1, dtype=np.float32)
    for r in records:
        sum_wv[r['time']:r['time'] + r['length']] += r['data'][:r['length']]
    # Select the part inside the peak
    sum_wv = sum_wv[peak_left:peak_left + peak_length]

    assert len(sum_wv) == peak_length
    assert np.all(p['data'][:peak_length] == sum_wv)
Example 4
    class Peaks(strax.Plugin):
        provides = 'peaks'
        depends_on = ('records', )
        dtype = strax.peak_dtype()

        def compute(self, records):
            p = np.zeros(len(records), self.dtype)
            p['time'] = records['time']
            return p
Example 5
class FunnyPeaks(strax.Plugin):
    parallel = True
    provides = 'peaks'
    depends_on = 'even_recs'
    dtype = strax.peak_dtype()

    def compute(self, even_recs):
        p = np.zeros(len(even_recs), self.dtype)
        p['time'] = even_recs['time']
        return p
Example 6
class Peaks(strax.Plugin):
    provides = 'peaks'
    depends_on = ('records', )
    dtype = strax.peak_dtype()

    def compute(self, records):
        if self.config['give_wrong_dtype']:
            return np.zeros(5, [('a', int), ('b', float)])
        p = np.zeros(len(records), self.dtype)
        p['time'] = records['time']
        return p
Example 7
def test_core():
    for max_workers in [1, 2]:
        mystrax = strax.Context(
            storage=[],
            register=[Records, Peaks],
        )
        bla = mystrax.get_array(run_id=run_id,
                                targets='peaks',
                                max_workers=max_workers)
        assert len(bla) == recs_per_chunk * n_chunks
        assert bla.dtype == strax.peak_dtype()
class Peaks(strax.Plugin):
    parallel = True
    provides = 'peaks'
    depends_on = ('records', )
    dtype = strax.peak_dtype()

    def compute(self, records):
        assert isinstance(records, np.ndarray), \
            f"Recieved {type(records)} instead of numpy array!"
        p = np.zeros(len(records), self.dtype)
        p['time'] = records['time']
        return p
Example 9
def test_splitter_outer():
    data = [0, 2, 2, 0, 2, 2, 1]
    records = np.zeros(1, dtype=strax.record_dtype(len(data)))
    records['dt'] = 1
    records['data'] = data
    records['length'] = len(data)
    records['pulse_length'] = len(data)
    to_pe = np.ones(10)

    hits = strax.find_hits(records, np.ones(1))
    hits['left_integration'] = hits['left']
    hits['right_integration'] = hits['right']
    peaks = np.zeros(1, dtype=strax.peak_dtype())
    hitlets = np.zeros(1, dtype=strax.hitlet_with_data_dtype(10))
    for data_type in (peaks, hitlets):
        data_type['dt'] = 1
        data_type['data'][0, :len(data)] = data
        data_type['length'] = len(data)

    rlinks = strax.record_links(records)
    peaks = strax.split_peaks(peaks,
                              hits,
                              records,
                              rlinks,
                              to_pe,
                              algorithm='local_minimum',
                              data_type='peaks',
                              min_height=1,
                              min_ratio=0)

    hitlets = strax.split_peaks(hitlets,
                                hits,
                                records,
                                rlinks,
                                to_pe,
                                algorithm='local_minimum',
                                data_type='hitlets',
                                min_height=1,
                                min_ratio=0)

    for name, data_type in zip(('peaks', 'hitlets'), (peaks, hitlets)):
        data = data_type[0]['data'][:data_type[0]['length']]
        assert np.all(
            data == [0, 2, 2]
        ), f'Wrong split for {name}, got {data}, expected {[0, 2, 2]}.'
        data = data_type[1]['data'][:data_type[1]['length']]
        assert np.all(
            data == [0, 2, 2, 1]
        ), f'Wrong split for {name}, got {data}, expected {[0, 2, 2, 1]}.'
Example 10
class Peaks(strax.Plugin):
    provides = 'peaks'
    data_kind = 'peaks'
    depends_on = ('records', )
    dtype = strax.peak_dtype()
    parallel = True

    def compute(self, records):
        if self.config['give_wrong_dtype']:
            return np.zeros(5, [('a', int), ('b', float)])
        p = np.zeros(len(records), self.dtype)
        p['time'] = records['time']
        p['length'] = p['dt'] = 1
        p['area'] = self.config['base_area'] + self.config['bonus_area']
        return p
def test_processing():
    """Test ParallelSource plugin under several conditions"""
    # It's always harder with a small mailbox:
    strax.Mailbox.DEFAULT_MAX_MESSAGES = 2
    for request_peaks in (True, False):
        for peaks_parallel in (True, False):
            for max_workers in (1, 2):
                Peaks.parallel = peaks_parallel
                print(f"\nTesting with request_peaks {request_peaks}, "
                      f"peaks_parallel {peaks_parallel}, "
                      f"max_workers {max_workers}")

                mystrax = strax.Context(storage=[], register=[Records, Peaks])
                bla = mystrax.get_array(
                    run_id=run_id,
                    targets='peaks' if request_peaks else 'records',
                    max_workers=max_workers)
                assert len(bla) == recs_per_chunk * n_chunks
                assert bla.dtype == (strax.peak_dtype() if request_peaks else
                                     strax.record_dtype())
Example 12
 def infer_dtype(self):
     return strax.peak_dtype(n_channels=self.config['n_he_pmts'])
Example 13
 def infer_dtype(self):
     return dict(
         peaklets=strax.peak_dtype(n_channels=self.config['n_tpc_pmts']),
         lone_hits=strax.hit_dtype)
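Returning a dict from infer_dtype implies a plugin that provides more than one data type. Below is a hedged skeleton of such a plugin, assuming strax's multi-output convention (provides as a tuple, data_kind as a dict, compute returning a dict keyed by data type name); the class name, config option and default value are illustrative only, not taken from straxen.

import numpy as np
import strax


@strax.takes_config(
    strax.Option('n_tpc_pmts', default=100,
                 help='Number of TPC PMTs (illustrative default)'))
class PeakletsSketch(strax.Plugin):
    """Hedged multi-output skeleton around the dict-returning infer_dtype above."""
    depends_on = ('records',)
    provides = ('peaklets', 'lone_hits')
    data_kind = dict(peaklets='peaklets', lone_hits='lone_hits')

    def infer_dtype(self):
        return dict(
            peaklets=strax.peak_dtype(n_channels=self.config['n_tpc_pmts']),
            lone_hits=strax.hit_dtype)

    def compute(self, records):
        # One (here empty) array per provided data type, keyed like infer_dtype
        return dict(
            peaklets=np.zeros(0, strax.peak_dtype(
                n_channels=self.config['n_tpc_pmts'])),
            lone_hits=np.zeros(0, strax.hit_dtype))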
Example 14
def test_peak_overflow(
    records,
    gap_factor,
    right_extension,
    gap_threshold,
    max_duration,
):
    """
    Test that we handle dt overflows in peaks correctly. To this end, we
        create a set of records and copy it several times, so that we
        end up with a very long artificial train of hits that can be
        used in the peak building. By setting the peak-finding
        parameters to very strange values we can replicate the
        behaviour where a peak becomes so large that it cannot be
        written out correctly due to integer overflow of the dt field.
    :param records: records
    :param gap_factor: factor by which the time field is multiplied to
        create very extended sets of records, i.e. to arrive at a very
        long pulse-train more quickly
    :param max_duration: max_duration option for strax.find_peaks
    :param right_extension: option for strax.find_peaks
    :param gap_threshold: option for strax.find_peaks
    :return: None
    """

    # Set this here, no need to test left and right independently
    left_extension = 0
    # Build an empty array just to get the peak dtype; we only need the
    # shape of its 'data' (sum waveform) buffer
    peak_dtype = np.zeros(0, strax.peak_dtype()).dtype
    # NB! This int16 limit only applies to before #403; the dt field is now
    # int32, and overflowing int32 would make this test take forever.
    magic_overflow_time = np.iinfo(np.int16).max * peak_dtype['data'].shape[0]

    def return_1(x):
        """
        Return an array of ones, one per element of the input; used as
            the threshold parameter for the natural-breaks splitting
        :param x: any array-like
        :return: np.ones(len(x))
        """
        return np.ones(len(x))

    r = records
    if not len(r) or len(r['channel']) == 1:
        # Hard to test integer overflow for empty records or with
        # records only from a single channel
        return

    # Copy the pulse train of the records. We are going to copy the same
    # set of records many times now.
    t_max = strax.endtime(r).max()
    print('make buffer')
    n_repeat = int(1.5 * magic_overflow_time + t_max * gap_factor) // int(
        t_max * gap_factor) + 1
    time_offset = np.linspace(0,
                              1.5 * magic_overflow_time + t_max * gap_factor,
                              n_repeat,
                              dtype=np.int64)
    r_buffer = np.tile(r, n_repeat // len(r) + 1)[:len(time_offset)]
    assert len(r_buffer) == len(time_offset)
    r_buffer['time'] = r_buffer['time'] + time_offset
    assert strax.endtime(
        r_buffer[-1]) - r_buffer['time'].min() > magic_overflow_time
    r = r_buffer.copy()
    del r_buffer
    print(f'Array is {r.nbytes/1e6} MB, good luck')

    # Do peak finding!
    print(f'Find hits')
    hits = strax.find_hits(r, min_amplitude=0)
    assert len(hits)
    hits = strax.sort_by_time(hits)

    # Dummy to_pe
    to_pe = np.ones(max(r['channel']) + 1)

    try:
        print('Find peaks')
        # Find peaks, we might end up with negative dt here!
        p = strax.find_peaks(
            hits,
            to_pe,
            gap_threshold=gap_threshold,
            left_extension=left_extension,
            right_extension=right_extension,
            max_duration=max_duration,
            # Due to these settings, we will start merging
            # whatever strax can get its hands on
            min_area=0.,
            min_channels=1,
        )
    except AssertionError as e:
        if not gap_threshold > left_extension + right_extension:
            print(f'Great, we are getting the assertion statement for the '
                  f'incongruent extensions')
            return
        elif not left_extension + max_duration + right_extension < magic_overflow_time:
            # Ending up here is the ultimate goal of the tests. This
            # means we are hitting github.com/AxFoundation/strax/issues/397
            print(f'Great, the test worked, we are getting the assertion '
                  f'statement for the int overflow')
            return
        else:
            # The error is caused by something else, we need to re-raise
            raise e

    print(f'Peaklet array is {p.nbytes / 1e6} MB, good luck')
    if len(p) == 0:
        print(f'rec length {len(r)}')
    assert len(p)
    assert np.all(p['dt'] > 0)

    # Double check that this error should have been raised.
    if not gap_threshold > left_extension + right_extension:
        raise ValueError(f'No assertion error raised! Working with '
                         f'{gap_threshold} {left_extension + right_extension}')

    # Compute basics
    hits = strax.find_hits(r, np.ones(10000))
    hits['left_integration'] = hits['left']
    hits['right_integration'] = hits['right']
    rlinks = strax.record_links(r)
    strax.sum_waveform(p, hits, r, rlinks, to_pe)
    strax.compute_widths(p)

    try:
        print('Split peaks')
        peaklets = strax.split_peaks(p,
                                     hits,
                                     r,
                                     rlinks,
                                     to_pe,
                                     algorithm='natural_breaks',
                                     threshold=return_1,
                                     split_low=True,
                                     filter_wing_width=70,
                                     min_area=0,
                                     do_iterations=2)
    except AssertionError as e:
        if not left_extension + max_duration + right_extension < magic_overflow_time:
            # Ending up here is the ultimate goal of the tests. This
            # means we are hitting github.com/AxFoundation/strax/issues/397
            print(f'Great, the test worked, we are getting the assertion '
                  f'statement for the int overflow')
            raise RuntimeError(
                'We were not properly warned of the imminent peril we are '
                'facing. This error means that the peak_finding is not '
                'protected against integer overflow in the dt field. Where is '
                'our white knight in shining armour to protect us from this '
                'imminent doom:\n'
                'github.com/AxFoundation/strax/issues/397') from e
        # We failed for another reason, we need to re-raise
        raise e

    assert len(peaklets)
    assert len(peaklets) <= len(r)
    # Integer overflow will manifest itself here again:
    assert np.all(peaklets['dt'] > 0)
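To make the docstring's overflow argument concrete, the sketch below (a restatement of the test's own construction, nothing more) computes the magic_overflow_time threshold: with a 16-bit dt and the default number of sum-waveform samples, a peak longer than iinfo(int16).max * n_samples nanoseconds cannot be written out without dt overflowing.

import numpy as np
import strax

# Number of samples in the peak sum-waveform buffer, extracted as in the test
n_samples = np.zeros(0, strax.peak_dtype()).dtype['data'].shape[0]

# Longest peak an int16 dt can represent: every sample stretched to the
# maximum sample width (times are in ns throughout this listing)
magic_overflow_time = np.iinfo(np.int16).max * n_samples
print(n_samples, magic_overflow_time)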
Example 15
class PeakSplitter:
    find_split_args_defaults: tuple

    def __call__(self,
                 peaks,
                 records,
                 to_pe,
                 do_iterations=1,
                 min_area=0,
                 **kwargs):
        if not len(records) or not len(peaks) or not do_iterations:
            return peaks

        # Build the *args tuple for self.find_split_points from kwargs
        # since numba doesn't support **kwargs
        args_options = []
        for i, (k, value) in enumerate(self.find_split_args_defaults):
            if k in kwargs:
                value = kwargs[k]
            if k == 'threshold':
                # The 'threshold' option is a user-specified function
                value = value(peaks)
            args_options.append(value)
        args_options = tuple(args_options)

        # Check for spurious options
        argnames = [k for k, _ in self.find_split_args_defaults]
        for k in kwargs:
            if k not in argnames:
                raise TypeError(f"Unknown argument {k} for {self.__class__}")

        is_split = np.zeros(len(peaks), dtype=np.bool_)

        new_peaks = self.split_peaks(
            # Numba doesn't like self as argument, but it's ok with functions...
            split_finder=self.find_split_points,
            peaks=peaks,
            is_split=is_split,
            orig_dt=records[0]['dt'],
            min_area=min_area,
            args_options=tuple(args_options),
            result_dtype=peaks.dtype)

        if is_split.sum() != 0:
            # Found new peaks: compute basic properties
            strax.sum_waveform(new_peaks, records, to_pe)
            strax.compute_widths(new_peaks)

            # ... and recurse (if needed)
            new_peaks = self(new_peaks,
                             records,
                             to_pe,
                             do_iterations=do_iterations - 1,
                             min_area=min_area,
                             **kwargs)
            peaks = strax.sort_by_time(
                np.concatenate([peaks[~is_split], new_peaks]))

        return peaks

    @staticmethod
    @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
    @numba.jit(nopython=True, nogil=True)
    def split_peaks(split_finder,
                    peaks,
                    orig_dt,
                    is_split,
                    min_area,
                    args_options,
                    _result_buffer=None,
                    result_dtype=None):
        # TODO NEEDS TESTS!
        new_peaks = _result_buffer
        offset = 0

        for p_i, p in enumerate(peaks):
            if p['area'] < min_area:
                continue

            prev_split_i = 0
            w = p['data'][:p['length']]

            for split_i, bonus_output in split_finder(w, p['dt'], p_i,
                                                      *args_options):
                if split_i == NO_MORE_SPLITS:
                    p['max_goodness_of_split'] = bonus_output
                    # although the iteration will end anyway afterwards:
                    continue

                is_split[p_i] = True
                r = new_peaks[offset]
                r['time'] = p['time'] + prev_split_i * p['dt']
                r['channel'] = p['channel']
                # Set the dt to the original (lowest) dt first;
                # this may change when the sum waveform of the new peak
                # is computed
                r['dt'] = orig_dt
                r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt

                r['max_gap'] = -1  # Too lazy to compute this

                if r['length'] <= 0:
                    print(p['data'])
                    print(prev_split_i, split_i)
                    raise ValueError("Attempt to create invalid peak!")

                offset += 1
                if offset == len(new_peaks):
                    yield offset
                    offset = 0

                prev_split_i = split_i

        yield offset

    @staticmethod
    def find_split_points(w, *args_options):
        raise NotImplementedError
Example 16
 def infer_dtype(self):
     # We are going to check later that the infer_dtype is always called.
     dtype = strax.peak_dtype() + [(
         ('PMT with median most records', 'max_pmt'), np.int16)]
     self.dtype_is_set = True
     return dtype
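This works because strax.peak_dtype() can simply be concatenated with further field descriptors before the array is built. A small sketch (the extra field and its values are illustrative only) of creating and filling such an extended array:

import numpy as np
import strax

# Append one extra field, as in the infer_dtype above; the ('title', 'name')
# tuple gives the field both a human-readable description and a short name.
dtype = strax.peak_dtype() + [
    (('PMT with median most records', 'max_pmt'), np.int16)]

peaks = np.zeros(3, dtype=dtype)
peaks['max_pmt'] = [1, 2, 3]      # illustrative values only
print(peaks.dtype.names[-1])      # 'max_pmt'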
Example 17
 def infer_dtype(self):
     return strax.peak_dtype(self.config['samples_per_record'])
Example 18
 def infer_dtype(self):
     self.to_pe = get_to_pe(self.run_id, self.config['to_pe_file'])
     return strax.peak_dtype(n_channels=len(self.to_pe))
Example 19
def test_core():
    mystrax = strax.Context(storage=[], register=[Records, Peaks])
    bla = mystrax.get_array(run_id=run_id, targets='peaks')
    assert len(bla) == recs_per_chunk * n_chunks
    assert bla.dtype == strax.peak_dtype()
Example 20
def show_time_range(st, run_id, t0, dt=10):
    from functools import partial

    import numpy as np
    import pandas as pd

    import holoviews as hv
    from holoviews.operation.datashader import datashade, dynspread
    hv.extension('bokeh')

    import strax

    import gc
    # Somebody thought it was a good idea to call gc.collect explicitly somewhere in holoviews
    # This makes dynamic PMT maps super slow
    # Until I trace the offender:
    gc.collect = lambda *args, **kwargs: None

    # Custom wheel zoom tool that only zooms in time
    from bokeh.models import WheelZoomTool
    time_zoom = WheelZoomTool(dimensions='width')

    # Get ADC->pe multiplicative conversion factor
    from pax.configuration import load_configuration
    from pax.dsputils import adc_to_pe
    pax_config = load_configuration('XENON1T')["DEFAULT"]
    to_pe = np.array(
        [adc_to_pe(pax_config, ch) for ch in range(pax_config['n_channels'])])

    tpc_r = pax_config['tpc_radius']

    # Get locations of PMTs
    r = []
    for q in pax_config['pmts']:
        r.append(
            dict(x=q['position']['x'],
                 y=q['position']['y'],
                 i=q['pmt_position'],
                 array=q.get('array', 'other')))
    f = 1.08
    pmt_locs = pd.DataFrame(r)

    records = st.get_array(run_id,
                           'raw_records',
                           time_range=(t0, t0 + int(1e10)))

    # TODO: don't reprocess, just load...
    hits = strax.find_hits(records)
    peaks = strax.find_peaks(hits,
                             to_pe,
                             gap_threshold=300,
                             min_hits=3,
                             result_dtype=strax.peak_dtype(n_channels=260))
    strax.sum_waveform(peaks, records, to_pe)
    # Integral in pe
    areas = records['data'].sum(axis=1) * to_pe[records['channel']]

    def normalize_time(t):
        return (t - records[0]['time']) / 1e9

    # Create dataframe with record metadata
    df = pd.DataFrame(
        dict(area=areas,
             time=normalize_time(records['time']),
             channel=records['channel']))

    # Convert to holoviews Points
    points = hv.Points(
        df,
        kdims=[
            hv.Dimension('time', label='Time', unit='sec'),
            hv.Dimension('channel', label='PMT number', range=(0, 260))
        ],
        vdims=[
            hv.Dimension(
                'area',
                label='Area',
                unit='pe',
                # range=(0, 1000)
            )
        ])

    def pmt_map(t_0, t_1, array='top', **kwargs):
        # Compute the PMT pattern (fast)
        ps = points[(t_0 <= points['time']) & (points['time'] < t_1)]
        areas = np.bincount(ps['channel'],
                            weights=ps['area'],
                            minlength=len(pmt_locs))

        # Which PMTs should we include?
        pmt_mask = pmt_locs['array'] == array
        d = pmt_locs[pmt_mask].copy()
        d['area'] = areas[pmt_mask]

        # Convert to holoviews points
        d = hv.Dataset(d,
                       kdims=[
                           hv.Dimension('x',
                                        unit='cm',
                                        range=(-tpc_r * f, tpc_r * f)),
                           hv.Dimension('y',
                                        unit='cm',
                                        range=(-tpc_r * f, tpc_r * f)),
                           hv.Dimension('i', label='PMT number'),
                           hv.Dimension('area', label='Area', unit='PE')
                       ])

        return d.to(hv.Points,
                    vdims=['area', 'i'],
                    group='PMTPattern',
                    label=array.capitalize(),
                    **kwargs).opts(plot=dict(color_index=2,
                                             tools=['hover'],
                                             show_grid=False),
                                   style=dict(size=17, cmap='magma'))

    def pmt_map_range(x_range, array='top', **kwargs):
        # For use in dynamicmap with streams
        if x_range is None:
            x_range = (0, 0)
        return pmt_map(x_range[0], x_range[1], array=array, **kwargs)

    xrange_stream = hv.streams.RangeX(source=points)

    # TODO: weigh by area

    def channel_map():
        return dynspread(
            datashade(
                points, y_range=(0, 260),
                streams=[xrange_stream])).opts(plot=dict(
                    width=600,
                    tools=[time_zoom, 'xpan'],
                    default_tools=['save', 'pan', 'box_zoom', 'save', 'reset'],
                    show_grid=False))

    def plot_peak(p):
        # It's better to plot amplitude /time than per bin, since
        # sampling times are now variable
        y = p['data'][:p['length']] / p['dt']
        t_edges = np.arange(p['length'] + 1, dtype=np.int64)
        t_edges = t_edges * p['dt'] + p['time']
        t_edges = normalize_time(t_edges)

        # Correct step plotting from Knut
        t_ = np.zeros(2 * len(y))
        y_ = np.zeros(2 * len(y))
        t_[0::2] = t_edges[0:-1]
        t_[1::2] = t_edges[1::]
        y_[0::2] = y
        y_[1::2] = y

        c = hv.Curve(dict(time=t_, amplitude=y_),
                     kdims=points.kdims[0],
                     vdims=hv.Dimension('amplitude',
                                        label='Amplitude',
                                        unit='PE/ns'),
                     group='PeakSumWaveform')
        return c.opts(
            plot=dict(  # interpolation='steps-mid',
                # default_tools=['save', 'pan', 'box_zoom', 'save', 'reset'],
                # tools=[time_zoom, 'xpan'],
                width=600,
                shared_axes=False,
                show_grid=True),
            style=dict(color='b')
            # norm=dict(framewise=True)
        )

    def peaks_in(t_0, t_1):
        return peaks[(normalize_time(peaks['time'] +
                                     peaks['length'] * peaks['dt']) > t_0)
                     & (normalize_time(peaks['time']) < t_1)]

    def plot_peaks(t_0, t_1, n_max=10):
        # Find peaks in this range
        ps = peaks_in(t_0, t_1)
        # Show only the largest n_max peaks
        if len(ps) > n_max:
            areas = ps['area']
            max_area = np.sort(areas)[-n_max]
            ps = ps[areas >= max_area]

        return hv.Overlay(items=[plot_peak(p) for p in ps])

    def plot_peak_range(x_range, **kwargs):
        # For use in dynamicmap with streams
        if x_range is None:
            x_range = (0, 10)
        return plot_peaks(x_range[0], x_range[1], **kwargs)

    top_map = hv.DynamicMap(partial(pmt_map_range, array='top'),
                            streams=[xrange_stream])
    bot_map = hv.DynamicMap(partial(pmt_map_range, array='bottom'),
                            streams=[xrange_stream])
    waveform = hv.DynamicMap(plot_peak_range, streams=[xrange_stream])
    layout = waveform + top_map + channel_map() + bot_map
    return layout.cols(2)
Example 21
 def infer_dtype(self):
     return dict(peaklets=strax.peak_dtype(n_channels=straxen.n_tpc_pmts),
                 lone_hits=strax.hit_dtype)
def software_he_veto(records, to_pe,
                     area_threshold=int(1e5),
                     veto_length=int(3e6),
                     veto_res=int(1e3), pass_veto_fraction=0.01,
                     pass_veto_extend=3):
    """Veto veto_length (time in ns) after peaks larger than
    area_threshold (in PE).

    Further large peaks inside the veto regions are still passed:
    We sum the waveform inside the veto region (with time resolution
    veto_res in ns) and pass regions within pass_veto_extend samples
    of samples with amplitude above pass_veto_fraction times the maximum.

    :returns: (preserved records, vetoed records, veto intervals).

    :param records: PMT records
    :param to_pe: ADC to PE conversion factors for the channels in records.
    :param area_threshold: Minimum peak area to trigger the veto.
    Note we use a much rougher clustering than in later processing.
    :param veto_length: Time in ns to veto after the peak
    :param veto_res: Resolution of the sum waveform inside the veto region.
    Do not make too large without increasing integer type in some strax
    dtypes...
    :param pass_veto_fraction: fraction of maximum sum waveform amplitude to
    trigger veto passing of further peaks
    :param pass_veto_extend: samples to extend (left and right) the pass veto
    regions.
    """
    veto_res = int(veto_res)
    if veto_res > np.iinfo(np.int16).max:
        raise ValueError("Veto resolution does not fit 16-bit int")
    veto_length = np.ceil(veto_length / veto_res).astype(int) * veto_res
    veto_n = int(veto_length / veto_res) + 1

    # 1. Find large peaks in the data.
    # This will actually return big agglomerations of peaks and their tails
    peaks = strax.find_peaks(
        records, to_pe,
        gap_threshold=1,
        left_extension=0,
        right_extension=0,
        min_channels=100,
        min_area=area_threshold,
        result_dtype=strax.peak_dtype(n_channels=len(to_pe),
                                      n_sum_wv_samples=veto_n))

    # 2. Find initial veto regions around these peaks
    # (with a generous right extension)
    veto_start, veto_end = strax.find_peak_groups(
        peaks,
        gap_threshold=veto_length + 2 * veto_res,
        right_extension=veto_length,
        left_extension=veto_res)
    veto_end = veto_end.clip(0, strax.endtime(records[-1]))
    veto_length = veto_end - veto_start
    # dtype is like record (since we want to use hitfinding etc.)
    # but with float32 waveform
    regions = np.zeros(
        len(veto_start),
        dtype=strax.interval_dtype + [
            ("data", (np.float32, veto_n)),
            ("baseline", np.float32),
            ("reduction_level", np.int64),
            ("record_i", np.int64),
            ("pulse_length", np.int64),
        ])
    regions['time'] = veto_start
    regions['length'] = veto_length
    regions['pulse_length'] = veto_length
    regions['dt'] = veto_res

    if not len(regions):
        # No veto anywhere in this data
        return records, records[:0], np.zeros(0, strax.hit_dtype)

    # 3. Find pass_veto regions with big peaks inside the veto regions.
    # For this we compute a rough sum waveform (at low resolution,
    # without looping over the pulse data)
    rough_sum(regions, records, to_pe, veto_n, veto_res)
    regions['data'] /= np.max(regions['data'], axis=1)[:, np.newaxis]
    pass_veto = strax.find_hits(regions, threshold=pass_veto_fraction)

    # 4. Extend these by a few samples and invert to find the veto regions
    regions['data'] = 1
    regions = strax.cut_outside_hits(
        regions,
        pass_veto,
        left_extension=pass_veto_extend,
        right_extension=pass_veto_extend)
    regions['data'] = 1 - regions['data']
    veto = strax.find_hits(regions, threshold=0.5)
    # Do not remove very tiny regions
    veto = veto[veto['length'] > 2 * pass_veto_extend]

    # 5. Apply the veto and return results
    veto_mask = strax.fully_contained_in(records, veto) == -1
    return tuple(list(_mask_and_not(records, veto_mask)) + [veto])
Example 23
def software_he_veto(records,
                     to_pe,
                     chunk_end,
                     area_threshold=int(1e5),
                     veto_length=int(3e6),
                     veto_res=int(1e3),
                     pass_veto_fraction=0.01,
                     pass_veto_extend=3,
                     max_veto_value=None):
    """Veto veto_length (time in ns) after peaks larger than
    area_threshold (in PE).

    Further large peaks inside the veto regions are still passed:
    We sum the waveform inside the veto region (with time resolution
    veto_res in ns) and pass regions within pass_veto_extend samples
    of samples with amplitude above pass_veto_fraction times the maximum.

    :returns: (preserved records, vetoed records, veto intervals).

    :param records: PMT records
    :param to_pe: ADC to PE conversion factors for the channels in records.
    :param chunk_end: Endtime of chunk to set as maximum ceiling for the veto period
    :param area_threshold: Minimum peak area to trigger the veto.
    Note we use a much rougher clustering than in later processing.
    :param veto_length: Time in ns to veto after the peak
    :param veto_res: Resolution of the sum waveform inside the veto region.
    Do not make too large without increasing integer type in some strax
    dtypes...
    :param pass_veto_fraction: fraction of maximum sum waveform amplitude to
    trigger veto passing of further peaks
    :param pass_veto_extend: samples to extend (left and right) the pass veto
    regions.
    :param max_veto_value: if not None, pass peaks that exceed this area
    no matter what.
    """
    veto_res = int(veto_res)
    if veto_res > np.iinfo(np.int16).max:
        raise ValueError("Veto resolution does not fit 16-bit int")
    veto_length = np.ceil(veto_length / veto_res).astype(int) * veto_res
    veto_n = int(veto_length / veto_res) + 1

    # 1. Find large peaks in the data.
    # This will actually return big agglomerations of peaks and their tails
    peaks = strax.find_peaks(records,
                             to_pe,
                             gap_threshold=1,
                             left_extension=0,
                             right_extension=0,
                             min_channels=100,
                             min_area=area_threshold,
                             result_dtype=strax.peak_dtype(
                                 n_channels=len(to_pe),
                                 n_sum_wv_samples=veto_n))

    # 2a. Set 'candidate regions' at these peaks. These should:
    #  - Have a fixed maximum length (else we can't use the strax hitfinder on them)
    #  - Never extend beyond the current chunk
    #  - Do not overlap
    veto_start = peaks['time']
    veto_end = np.clip(peaks['time'] + veto_length, None, chunk_end)
    veto_end[:-1] = np.clip(veto_end[:-1], None, veto_start[1:])

    # 2b. Convert these into strax record-like objects
    # Note the waveform is float32 though (it's a summed waveform)
    regions = np.zeros(len(veto_start),
                       dtype=strax.interval_dtype + [
                           ("data", (np.float32, veto_n)),
                           ("baseline", np.float32),
                           ("baseline_rms", np.float32),
                           ("reduction_level", np.int64),
                           ("record_i", np.int64),
                           ("pulse_length", np.int64),
                       ])
    regions['time'] = veto_start
    regions['length'] = (veto_end - veto_start) // veto_n
    regions['pulse_length'] = veto_n
    regions['dt'] = veto_res

    if not len(regions):
        # No veto anywhere in this data
        return records, records[:0], np.zeros(0, strax.hit_dtype)

    # 3. Find pass_veto regions with big peaks inside the veto regions.
    # For this we compute a rough sum waveform (at low resolution,
    # without looping over the pulse data)
    rough_sum(regions, records, to_pe, veto_n, veto_res)
    if max_veto_value is not None:
        pass_veto = strax.find_hits(regions, min_amplitude=max_veto_value)
    else:
        regions['data'] /= np.max(regions['data'], axis=1)[:, np.newaxis]
        pass_veto = strax.find_hits(regions, min_amplitude=pass_veto_fraction)

    # 4. Extend these by a few samples and invert to find the veto regions
    regions['data'] = 1
    regions = strax.cut_outside_hits(regions,
                                     pass_veto,
                                     left_extension=pass_veto_extend,
                                     right_extension=pass_veto_extend)
    regions['data'] = 1 - regions['data']
    veto = strax.find_hits(regions, min_amplitude=1)
    # Do not remove very tiny regions
    veto = veto[veto['length'] > 2 * pass_veto_extend]

    # 5. Apply the veto and return results
    veto_mask = strax.fully_contained_in(records, veto) == -1
    return tuple(list(mask_and_not(records, veto_mask)) + [veto])
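A hedged call-site sketch for the signature above; records and to_pe are assumed to be the usual strax record array and per-channel gain array from the other examples, and the helpers rough_sum and mask_and_not used internally are not shown in this listing.

# Hedged usage sketch, assuming `records` and `to_pe` already exist.
chunk_end = strax.endtime(records).max()
kept_records, vetoed_records, veto_intervals = software_he_veto(
    records, to_pe, chunk_end,
    area_threshold=int(1e5),   # PE, default from the signature above
    veto_length=int(3e6),      # ns
    veto_res=int(1e3))         # ns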
Example 24
 def infer_dtype(self):
     self.dtype = strax.peak_dtype(n_channels=self.config['context_option'])
     return self.dtype
Example 25
class PeakSplitter:
    """Split peaks into more peaks based on arbitrary algorithm.
    :param peaks: Original peaks. Sum waveform must have been built
    and properties must have been computed (if you use them).
    :param records: Records from which peaks were built.
    :param to_pe: ADC to PE conversion factor array (of n_channels).
    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
        sum_waveform or get_hitlets_data to compute the waveform of the
        new split peaks/hitlets.
    :param next_ri: Index of next record for current record record_i.
        None if not needed.
    :param do_iterations: maximum number of times peaks are recursively split.
    :param min_area: Minimum area to do split. Smaller peaks are not split.

    The function find_split_points(), implemented in each subclass,
    defines the algorithm, which takes in a peak's waveform and
    returns the index to split the peak at, if a split point is
    found. Otherwise NO_MORE_SPLITS is returned and the peak is
    left as is.
    """
    find_split_args_defaults: tuple

    def __call__(self, peaks, records, to_pe, data_type,
                 next_ri=None, do_iterations=1, min_area=0, **kwargs):
        if not len(records) or not len(peaks) or not do_iterations:
            return peaks

        # Build the *args tuple for self.find_split_points from kwargs
        # since numba doesn't support **kwargs
        args_options = []
        for i, (k, value) in enumerate(self.find_split_args_defaults):
            if k in kwargs:
                value = kwargs[k]
            if k == 'threshold':
                # The 'threshold' option is a user-specified function
                value = value(peaks)
            args_options.append(value)
        args_options = tuple(args_options)

        # Check for spurious options
        argnames = [k for k, _ in self.find_split_args_defaults]
        for k in kwargs:
            if k not in argnames:
                raise TypeError(f"Unknown argument {k} for {self.__class__}")

        is_split = np.zeros(len(peaks), dtype=np.bool_)

        split_function = {'peaks': self._split_peaks,
                          'hitlets': self._split_hitlets}
        if data_type not in split_function:
            raise ValueError(f'Data_type "{data_type}" is not supported.')

        new_peaks = split_function[data_type](
            # Numba doesn't like self as argument, but it's ok with functions...
            split_finder=self.find_split_points,
            peaks=peaks,
            is_split=is_split,
            orig_dt=records[0]['dt'],
            min_area=min_area,
            args_options=tuple(args_options),
            result_dtype=peaks.dtype)

        if is_split.sum() != 0:
            # Found new peaks: compute basic properties
            if data_type == 'peaks':
                strax.sum_waveform(new_peaks, records, to_pe)
            elif data_type == 'hitlets':
                # Add record fields here
                strax.update_new_hitlets(new_peaks, records, next_ri, to_pe)

            strax.compute_widths(new_peaks)

            # ... and recurse (if needed)
            new_peaks = self(new_peaks, records, to_pe, data_type, next_ri,
                             do_iterations=do_iterations - 1,
                             min_area=min_area, **kwargs)
            peaks = strax.sort_by_time(np.concatenate([peaks[~is_split],
                                                       new_peaks]))

        return peaks

    @staticmethod
    @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
    @numba.jit(nopython=True, nogil=True)
    def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
                     args_options,
                     _result_buffer=None, result_dtype=None):
        """Loop over peaks, pass waveforms to algorithm, construct
        new peaks if and where a split occurs.
        """
        # NB: code very similar to _split_hitlets see
        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
        # that changing one function should also be reflected in the other.
        new_peaks = _result_buffer
        offset = 0

        for p_i, p in enumerate(peaks):
            if p['area'] < min_area:
                continue

            prev_split_i = 0
            w = p['data'][:p['length']]
            for split_i, bonus_output in split_finder(
                    w, p['dt'], p_i, *args_options):
                if split_i == NO_MORE_SPLITS:
                    p['max_goodness_of_split'] = bonus_output
                    # although the iteration will end anyway afterwards:
                    continue

                is_split[p_i] = True
                r = new_peaks[offset]
                r['time'] = p['time'] + prev_split_i * p['dt']
                r['channel'] = p['channel']
                # Set the dt to the original (lowest) dt first;
                # this may change when the sum waveform of the new peak
                # is computed
                r['dt'] = orig_dt
                r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
                r['max_gap'] = -1  # Too lazy to compute this
                if r['length'] <= 0:
                    print(p['data'])
                    print(prev_split_i, split_i)
                    raise ValueError("Attempt to create invalid peak!")

                offset += 1
                if offset == len(new_peaks):
                    yield offset
                    offset = 0

                prev_split_i = split_i

        yield offset

    @staticmethod
    @strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4))
    @numba.jit(nopython=True, nogil=True)
    def _split_hitlets(split_finder, peaks, orig_dt, is_split, min_area,
                       args_options,
                       _result_buffer=None, result_dtype=None):
        """Loop over hits, pass waveforms to algorithm, construct
        new hits if and where a split occurs.
        """
        # TODO NEEDS TESTS!
        # NB: code very similar to _split_peaks see
        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
        # that changing one function should also be reflected in the other.
        new_hits = _result_buffer
        offset = 0

        for h_i, h in enumerate(peaks):
            if h['area'] < min_area:
                continue

            prev_split_i = 0
            w = h['data'][:h['length']]
            for split_i, bonus_output in split_finder(
                    w, h['dt'], h_i, *args_options):
                if split_i == NO_MORE_SPLITS:
                    continue

                is_split[h_i] = True
                r = new_hits[offset]
                r['time'] = h['time'] + prev_split_i * h['dt']
                r['channel'] = h['channel']
                # Hitlet specific
                r['record_i'] = h['record_i']
                # Set the dt to the original (lowest) dt first;
                # this may change when the sum waveform of the new peak
                # is computed
                r['dt'] = orig_dt
                r['length'] = (split_i - prev_split_i) * h['dt'] / orig_dt
                if r['length'] <= 0:
                    print(h['data'])
                    print(prev_split_i, split_i)
                    raise ValueError("Attempt to create invalid hitlet!")

                offset += 1
                if offset == len(new_hits):
                    yield offset
                    offset = 0

                prev_split_i = split_i

        yield offset

    @staticmethod
    def find_split_points(w, dt, peak_i, *args_options):
        """This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter
        bare PeakSplitter class is not implemented"""
        raise NotImplementedError
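The docstring above says the actual split criterion lives in find_split_points of a subclass. Below is a hedged sketch of what such a subclass could look like, following the (name, default) convention that __call__ expects for find_split_args_defaults; the split criterion itself is illustrative and is not strax's LocalMinimumSplitter (the real splitters are also numba-compiled so that the nopython _split_peaks / _split_hitlets loops can call them).

class ThresholdSplitterSketch(PeakSplitter):
    """Illustrative subclass: propose a split wherever the waveform dips below min_height."""
    # (name, default) pairs; __call__ turns these into the *args tuple passed
    # to find_split_points, since numba does not support **kwargs.
    find_split_args_defaults = (('min_height', 0.),)

    @staticmethod
    def find_split_points(w, dt, peak_i, min_height):
        # Yield candidate split indices, then finish with the module-level
        # NO_MORE_SPLITS sentinel (used in _split_peaks above) plus a
        # 'goodness of split' value.
        for i in range(1, len(w) - 1):
            if w[i] < min_height and w[i] <= w[i - 1] and w[i] <= w[i + 1]:
                yield i, 0.
        yield NO_MORE_SPLITS, 0.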
Example 26
    Min_height is in pe/ns (NOT pe/bin!)
    """
    is_split = np.zeros(len(peaks), dtype=np.bool_)

    new_peaks = _split_peaks(peaks,
                             min_height=min_height,
                             min_ratio=min_ratio,
                             orig_dt=records[0]['dt'],
                             is_split=is_split,
                             result_dtype=peaks.dtype)
    strax.sum_waveform(new_peaks, records, to_pe)
    return strax.sort_by_time(np.concatenate([peaks[~is_split], new_peaks]))


@strax.utils.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True, cache=True)
def _split_peaks(peaks,
                 min_height,
                 min_ratio,
                 orig_dt,
                 is_split,
                 _result_buffer=None,
                 result_dtype=None):
    new_peaks = _result_buffer
    offset = 0

    for p_i, p in enumerate(peaks):
        prev_split_i = 0

        for split_i in find_split_points(p['data'][:p['length']],
Example 27
 def infer_dtype(self):
     # Loading here another config which will be different for child:
     self.dtype = strax.peak_dtype(
         n_channels=self.config['context_option_child'])
     return self.dtype