def test_overlap_indices(a1, n_a, b1, n_b):
    """Check strax.overlap_indices against a brute-force reference.

    The two intervals are [a1, a1 + n_a) and [b1, b1 + n_b). The index
    ranges returned for both intervals must have equal, non-negative length.
    When the intervals share samples, the ranges must equal the positions
    of the shared samples inside each interval; otherwise all four indices
    must be zero.
    """
    (a_start, a_end), (b_start, b_end) = strax.overlap_indices(
        a1, n_a, b1, n_b)

    assert a_end - a_start == b_end - b_start, "Overlap must be equal length"
    assert a_end >= a_start, "Overlap must be nonnegative"

    # Expected answer whenever the intervals have nothing in common
    nothing_found = a_start == a_end == b_start == b_end == 0

    if n_a == 0 or n_b == 0:
        # An empty interval cannot overlap with anything
        assert nothing_found
        return

    # Enumerate every sample of both intervals and intersect them
    samples_a = np.arange(a1, a1 + n_a)
    samples_b = np.arange(b1, b1 + n_b)
    shared = np.intersect1d(samples_a, samples_b)
    if len(shared) == 0:
        assert nothing_found
        return

    # Positions of the shared samples within each interval
    pos_a = np.searchsorted(samples_a, shared)
    pos_b = np.searchsorted(samples_b, shared)
    expected = ((pos_a[0], pos_a[-1] + 1),
                (pos_b[0], pos_b[-1] + 1))
    assert ((a_start, a_end), (b_start, b_end)) == expected
def _get_thing_data(thing, container):
    """Return the container samples that overlap with a thing.

    Note: thing must be of the interval dtype kind (the fields ``time``,
    ``dt`` and ``length`` are read from both arguments).

    :param thing: Interval whose overlap with the container is wanted.
    :param container: Interval carrying a ``data`` field.
    :returns: Tuple ``(data, thing_indices)``: the overlapping slice of
        ``container['data']`` and the matching (start, end) index pair
        inside the thing.
    """
    # Work in units of samples rather than time
    thing_start = thing['time'] // thing['dt']
    container_start = container['time'] // container['dt']

    thing_indices, container_indices = strax.overlap_indices(
        thing_start, thing['length'],
        container_start, container['length'])

    first, last = container_indices[0], container_indices[1]
    return container['data'][first:last], thing_indices
def _records_to_matrix(records, t0, window, n_channels, dt=10): n_samples = window // dt y = np.zeros((n_samples, n_channels), dtype=np.int32) for r in records: if r['channel'] > n_channels: continue (r_start, r_end), (y_start, y_end) = strax.overlap_indices( r['time'] // dt, r['length'], t0 // dt, n_samples) # += is paranoid, data in individual channels should not overlap # but... https://github.com/AxFoundation/strax/issues/119 y[y_start:y_end, r['channel']] += r['data'][r_start:r_end] return y
def _records_to_matrix_inner(records, t0, window, n_channels, dt=10): n_samples = (window // dt) + 1 # Use 32-bit integers, so downsampling saturated samples doesn't # cause wraparounds y = np.zeros((n_samples, n_channels), dtype=np.int32) if not len(records): return y samples_per_record = len(records[0]['data']) for r in records: if r['channel'] > n_channels: continue if dt >= samples_per_record * r['dt']: # Downsample to single sample -> store area idx = (r['time'] - t0) // dt if idx >= len(y): print(len(y), idx) raise IndexError('Despite n_samples = window // dt + 1, our ' 'idx is too high?!') y[idx, r['channel']] += r['area'] continue # Assume out-of-bounds data has been zeroed, so we do not # need to do r['data'][:r['length']] here. # This simplifies downsampling. w = r['data'].astype(np.int32) if dt > r['dt']: # Downsample duration = samples_per_record * r['dt'] if duration % dt != 0: raise ValueError("Cannot downsample fractionally") # .astype here keeps numba happy ... ?? w = w.reshape(duration // dt, -1).sum(axis=1).astype(np.int32) elif dt < r['dt']: raise ValueError("Upsampling not yet implemented") (r_start, r_end), (y_start, y_end) = strax.overlap_indices( r['time'] // dt, len(w), t0 // dt, n_samples) # += is paranoid, data in individual channels should not overlap # but... https://github.com/AxFoundation/strax/issues/119 y[y_start:y_end, r['channel']] += w[r_start:r_end] return y
def _get_hitlets_data(hitlets, records, to_pe):
    """Copy record samples into hitlets and convert them with to_pe.

    For every hitlet, the samples of all same-channel records that overlap
    with it are added to ``hitlet['data']`` (including the fractional part
    of the record baseline). ``time`` and ``length`` of the hitlet are then
    chopped to the region for which data was recorded, the data is scaled
    by ``to_pe`` of the hitlet channel and ``area`` is set to its sum.
    The hitlets are modified in place.

    :param hitlets: Hitlets to be filled.
    :param records: Records from which the samples are taken.
    :param to_pe: Per-channel conversion factor applied to the data.
    """
    record_ranges = _touching_windows(records['time'],
                                      strax.endtime(records),
                                      hitlets['time'],
                                      strax.endtime(hitlets))

    for hitlet_i, hitlet in enumerate(hitlets):
        first_sample = 0       # hitlet index of the first recorded sample
        n_filled = 0           # number of samples copied so far
        offset_known = False   # True once first_sample has been fixed

        start_r, stop_r = record_ranges[hitlet_i][0], record_ranges[hitlet_i][1]
        for record_i in range(start_r, stop_r):
            rec = records[record_i]
            if rec['channel'] != hitlet['channel']:
                continue

            (r_start, r_end), (h_start, h_end) = strax.overlap_indices(
                rec['time'] // rec['dt'], rec['length'],
                hitlet['time'] // hitlet['dt'], hitlet['length'])

            if r_end - r_start == 0 and h_end - h_start == 0:
                # _touching_windows returns the records touching a hitlet
                # regardless of channel. In rare cases a record of another
                # channel touches the hitlet and starts before the previous
                # record of the hitlet's own channel, so that one
                # non-overlapping record of this channel is part of the
                # range. Skip it.
                continue

            if not offset_known:
                # Hits may extend beyond the boundaries of the recorded
                # data, where the data is not defined, so the hitlet has
                # to be chopped and realigned. Example with fragments
                # [2, 2, 2, 2] [2, 2, 2], a hitfinder threshold of 1 and a
                # left/right extension of 3: the hitlet ranges from 3 to 8
                # in the first fragment and from 8 to 11 in the second.
                # An offset of 3 is therefore subtracted from every
                # h_start and h_end; time and length are updated below.
                offset_known = True
                first_sample = h_start

            h_start -= first_sample
            h_end -= first_sample
            hitlet['data'][h_start:h_end] += (rec['data'][r_start:r_end]
                                              + rec['baseline'] % 1)
            n_filled += r_end - r_start

        # Chop time and length in case hit extends into non-recorded regions.
        hitlet['time'] += int(first_sample * hitlet['dt'])
        hitlet['length'] = n_filled
        hitlet['data'][:] = hitlet['data'][:] * to_pe[hitlet['channel']]
        hitlet['area'] = np.sum(hitlet['data'])
def test_sum_waveform(records):
    """Check strax.sum_waveform against a directly summed waveform.

    Hits and peaks are built from the records with unit values for all
    ``n_ch`` channels; for each resulting peak the area bookkeeping must be
    consistent and the stored waveform must equal the sum of the
    overlapping record data.
    """
    # Make a single big peak to contain all the records
    n_ch = 100

    rlinks = strax.record_links(records)
    hits = strax.find_hits(records, np.ones(n_ch))
    # Integration bounds are simply the hit bounds
    hits['left_integration'] = hits['left']
    hits['right_integration'] = hits['right']
    hits = strax.sort_by_time(hits)

    peak_options = dict(gap_threshold=6,
                        left_extension=2,
                        right_extension=3,
                        min_area=0,
                        min_channels=1,
                        max_duration=10_000_000)
    peaks = strax.find_peaks(hits, np.ones(n_ch), **peak_options)
    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch))

    for peak in peaks:
        # Area measures must be consistent
        total_area = peak['area']
        assert total_area >= 0
        assert peak['data'].sum() == total_area
        assert peak['area_per_channel'].sum() == total_area

        # Build the expected waveform straight from the records
        expected_wv = np.zeros(peak['length'], dtype=np.float32)
        for rec in records:
            (r_lo, r_hi), (p_lo, p_hi) = strax.overlap_indices(
                rec['time'], rec['length'],
                peak['time'], peak['length'])
            expected_wv[p_lo:p_hi] += rec['data'][r_lo:r_hi]

        assert np.all(peak['data'][:peak['length']] == expected_wv)

    # Finally check that we also can use a selection of peaks to sum
    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch),
                       select_peaks_indices=np.array([0]))
def _cut_outside_hits(records, hits, new_recs, left_extension=2, right_extension=15): if not len(records): return samples_per_record = len(records[0]['data']) previous_record, next_record = record_links(records) for hit_i, h in enumerate(hits): rec_i = h['record_i'] r = records[rec_i] # Indices to keep, with 0 at the start of this record start_keep = h['left'] - left_extension end_keep = h['right'] + right_extension # Indices of samples to keep in this record (a, b), _ = strax.overlap_indices(0, r['length'], start_keep, end_keep - start_keep) new_recs[rec_i]['data'][a:b] = records[rec_i]['data'][a:b] # Keep samples in previous record, if there was one if start_keep < 0: prev_ri = previous_record[rec_i] if prev_ri != NO_RECORD_LINK: # Note start_keep is negative, so this keeps the # last few samples of the previous record a_prev = start_keep new_recs[prev_ri]['data'][a_prev:] = \ records[prev_ri]['data'][a_prev:] # Same for the next record, if there is one if end_keep > samples_per_record: next_ri = next_record[rec_i] if next_ri != NO_RECORD_LINK: b_next = end_keep - samples_per_record new_recs[next_ri]['data'][:b_next] = \ records[next_ri]['data'][:b_next]
def _build_hit_waveform(hit, record, hit_waveform):
    """Write the part of a record that overlaps with a hit into hit_waveform.

    hit_waveform is updated in place; the result is still in ADC counts
    (record data scaled by ``2**amplitude_bit_shift`` plus the fractional
    part of the baseline).

    :param hit: Hit whose waveform is being built.
    :param record: Record overlapping with the hit.
    :param hit_waveform: Buffer indexed relative to the start of the hit.
    :returns: np.int8 flag, 1 if the record saturated within the hit,
        i.e. its largest overlapping sample reaches the baseline.
    """
    hit_range, record_range = strax.overlap_indices(
        hit['time'] // hit['dt'], hit['length'],
        record['time'] // record['dt'], record['length'])
    h_lo, h_hi = hit_range
    r_lo, r_hi = record_range

    # Record properties needed for the conversion
    samples = record['data'][r_lo:r_hi]
    gain = 2**record['amplitude_bit_shift']
    baseline_fraction = record['baseline'] % 1
    highest_sample = samples.max() * gain

    # Fill the overlapping part of the hit waveform
    hit_waveform[h_lo:h_hi] = (gain * samples + baseline_fraction)

    return np.int8(highest_sample >= np.int16(record['baseline']))
def _peak_saturation_correction_inner( channel_saturated, records, p, to_pe, b_sumwf, b_pulse, b_index, reference_length=100, min_reference_length=20, ): """Would add a third level loop in peak_saturation_correction Which is not ideal for numba, thus this function is written :param channel_saturated: (bool, n_channels) :param p: One peak/peaklet :param to_pe: adc to PE conversion (length should equal number of PMTs) :param b_sumwf, b_pulse, b_index: Filled buffers """ dt = records['dt'][0] n_channels = len(channel_saturated) for ch in range(n_channels): if not channel_saturated[ch]: continue b = b_pulse[ch] r0 = records[b_index[ch][0]] # Define the reference region as reference_length before the first saturation point # unless there are not enough samples bl = np.inf for record_i in b_index[ch]: if record_i == -1: break bl = min(bl, records['baseline'][record_i]) s0 = np.argmax(b >= np.int16(bl)) ref = slice(max(0, s0 - reference_length), s0) if (b[ref] * to_pe[ch] > 1).sum() < min_reference_length: # the pulse is saturated, but there are not enough reference samples to get a good ratio # This actually distinguished between S1 and S2 and will only correct S2 signals continue if (b_sumwf[ref] > 1).sum() < min_reference_length: # the same condition applies to the waveform model continue if np.sum(b[ref]) * to_pe[ch] / np.sum(b_sumwf[ref]) > 1: # The pulse is saturated, but insufficient information is available in the other channels # to reliably reconstruct it continue scale = np.sum(b[ref]) / np.sum(b_sumwf[ref]) # Loop over the record indices of the saturated channel (saved in b_index buffer) for record_i in b_index[ch]: if record_i == -1: break r = records[record_i] r_slice, b_slice = strax.overlap_indices( r['time'] // dt, r['length'], p['time'] // dt + s0, p['length'] * p['dt'] // dt - s0) if r_slice[1] == r_slice[0]: # This record proceeds saturation continue b_slice = b_slice[0] + s0, b_slice[1] + s0 # First is finding the highest point in the desaturated 
record # because we need to bit shift the whole record if it exceeds int16 range apax = scale * max(b_sumwf[slice(*b_slice)]) if np.int32(apax) >= 2**15: # int16(2**15) is -2**15 bshift = int(np.floor(np.log2(apax) - 14)) tmp = r['data'].astype(np.int32) tmp[slice(*r_slice)] = b_sumwf[slice(*b_slice)] * scale r['area'] = np.sum(tmp) # Auto covert to int64 r['data'][:] = np.right_shift(tmp, bshift) r['amplitude_bit_shift'] += bshift else: r['data'][slice(*r_slice)] = b_sumwf[slice(*b_slice)] * scale r['area'] = np.sum(r['data'])
def peak_saturation_correction(
    records,
    rlinks,
    peaks,
    hitlets,
    to_pe,
    reference_length=100,
    min_reference_length=20,
    use_classification=False,
):
    """Correct the area and per pmt area of peaks from saturation.

    Records of saturated channels are modified in place, after which the
    sum waveform of the affected peaks is rebuilt.

    :param records: Records
    :param rlinks: strax.record_links of corresponding records.
    :param peaks: Peaklets / Peaks
    :param hitlets: Hitlets found in records to build peaks.
        (Hitlets are hits including the left/right extension)
    :param to_pe: adc to PE conversion (length should equal number of PMTs)
    :param reference_length: Maximum number of reference sample used
        to correct saturated samples
    :param min_reference_length: Minimum number of reference sample used
        to correct saturated samples
    :param use_classification: Option of using classification to pick only S2
    :returns: Indices of the peaks that were considered for correction, or
        None if records or peaks are empty.
    """
    if not len(records):
        return
    if not len(peaks):
        return

    # Search for peaks with saturated channels
    mask = peaks['n_saturated_channels'] > 0
    if use_classification:
        mask &= peaks['type'] == 2
    peak_list = np.where(mask)[0]
    # Look up records that touch each peak
    record_ranges = _touching_windows(
        records['time'],
        strax.endtime(records),
        peaks[peak_list]['time'],
        strax.endtime(peaks[peak_list]))

    # Create temporary arrays for calculation
    dt = records[0]['dt']
    n_channels = len(peaks[0]['saturated_channel'])
    len_buffer = np.max(peaks['length'] * peaks['dt']) // dt + 1
    max_nrecord = len_buffer // len(records[0]['data']) + 1

    # Buff the sum wf [pe] of non-saturated channels
    b_sumwf = np.zeros(len_buffer, dtype=np.float32)
    # Buff the records 'data' [ADC] in saturated channels
    b_pulse = np.zeros((n_channels, len_buffer), dtype=np.int16)
    # Buff the corresponding record index of saturated channels
    # NOTE(review): if one saturated channel has more than max_nrecord
    # records touching a peak, np.argmin below no longer finds a free (-1)
    # slot and an existing entry is overwritten -- confirm this cannot occur.
    b_index = np.zeros((n_channels, max_nrecord), dtype=np.int64)

    # Main
    for ix, peak_i in enumerate(peak_list):
        # reset buffers
        b_sumwf[:] = 0
        b_pulse[:] = 0
        b_index[:] = -1

        p = peaks[peak_i]
        channel_saturated = p['saturated_channel'] > 0

        for record_i in range(record_ranges[ix][0], record_ranges[ix][1]):
            r = records[record_i]
            r_slice, b_slice = strax.overlap_indices(
                r['time'] // dt, r['length'],
                p['time'] // dt, p['length'] * p['dt'] // dt)

            ch = r['channel']
            if channel_saturated[ch]:
                b_pulse[ch, slice(*b_slice)] += r['data'][slice(*r_slice)]
                # First free slot: unused entries are -1, record indices >= 0
                b_index[ch, np.argmin(b_index[ch])] = record_i
            else:
                b_sumwf[slice(*b_slice)] += r['data'][slice(*r_slice)] \
                    * to_pe[ch]

        _peak_saturation_correction_inner(
            channel_saturated, records, p,
            to_pe, b_sumwf, b_pulse, b_index,
            reference_length, min_reference_length)

        # Back track sum wf downsampling. Integer division keeps the
        # length an integer, like the overlap computation above.
        peaks[peak_i]['length'] = p['length'] * p['dt'] // dt
        peaks[peak_i]['dt'] = dt

    strax.sum_waveform(peaks, hitlets, records, rlinks, to_pe, peak_list)
    return peak_list
def sum_waveform(peaks, hits, records, record_links, adc_to_pe,
                 select_peaks_indices=None):
    """Compute sum waveforms for all peaks in peaks.

    Only builds the summed waveform over regions in which hits were found.
    This is required to avoid any bias due to zero-padding and baselining.
    Will downsample sum waveforms if they do not fit in per-peak buffer.
    Peaks are modified in place (area, area_per_channel, saturated_channel,
    n_saturated_channels and the stored waveform).

    :param peaks: Peaks for which the summed waveform should be build.
    :param hits: Hits which are inside peaks. Must be sorted according to
        record_i.
    :param records: Records to be used to build peaks.
    :param record_links: Tuple of previous and next records.
    :param adc_to_pe: Per-channel factor applied to the hit waveforms.
    :param select_peaks_indices: Indices of the peaks for partial
        processing, as an integer numpy array. If None (default), all the
        peaks are used for the summation.

    Assumes all peaks AND pulses have the same dt!
    """
    if not len(records):
        return
    if not len(peaks):
        return
    if select_peaks_indices is None:
        select_peaks_indices = np.arange(len(peaks))
    if not len(select_peaks_indices):
        return

    dt = records[0]['dt']
    n_samples_record = len(records[0]['data'])
    prev_record_i, next_record_i = record_links

    # Big buffer to hold even largest sum waveforms
    # Need a little more even for downsampling..
    swv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)

    area_per_channel = np.zeros(len(peaks[0]['area_per_channel']),
                                dtype=np.float32)

    # Index of the first hit that may still contribute to a peak
    left_h_i = 0
    # Scratch buffer for the waveform of a single hit
    hit_waveform = np.zeros(hits['length'].max(), dtype=np.float32)

    for peak_i in select_peaks_indices:
        p = peaks[peak_i]
        p_length = p['length']

        # Clear the relevant part of the swv buffer for use
        # (we clear a bit extra for use in downsampling)
        n_clear = min(2 * p_length, len(swv_buffer))
        swv_buffer[:n_clear] = 0

        # Clear area and area per channel
        # (in case find_peaks already populated them)
        area_per_channel *= 0
        p['area'] = 0

        # Find first hit that contributes to this peak
        for left_h_i in range(left_h_i, len(hits)):
            candidate = hits[left_h_i]
            # TODO: need test that fails if we replace < with <= here
            if p['time'] < candidate['time'] + candidate['length'] * dt:
                break
        else:
            # Hits exhausted before peaks exhausted
            # TODO: this is a strange case, maybe raise warning/error?
            break

        # Scan over hits that overlap with peak
        for right_h_i in range(left_h_i, len(hits)):
            h = hits[right_h_i]
            record_i = h['record_i']
            ch = h['channel']
            assert p['dt'] == h['dt'], "Hits and peaks must have same dt"

            # Offset of the peak start relative to the hit start, in samples
            shift = (p['time'] - h['time']) // dt
            n_samples_hit = h['length']

            if shift <= -p_length:
                # Hit is completely to the right of the peak;
                # we've seen all overlapping records
                break

            if n_samples_hit <= shift:
                # The (real) data in this record does not actually overlap
                # with the peak
                # (although a previous, longer hit did overlap)
                continue

            # Get overlapping samples between hit and peak:
            (h_start, h_end), (p_start, p_end) = strax.overlap_indices(
                h['time'] // dt, n_samples_hit,
                p['time'] // dt, p_length)

            hit_waveform[:] = 0

            # Record which belongs to main part of hit
            # (without integration bounds):
            is_saturated = _build_hit_waveform(h, records[record_i],
                                               hit_waveform)

            # Now check if we also have to go to prev/next record due to
            # integration bounds. If bounds are outside of peak we chop
            # when building the summed waveform later.
            prev_i = prev_record_i[record_i]
            if h['left_integration'] < 0 and prev_i != -1:
                is_saturated |= _build_hit_waveform(h, records[prev_i],
                                                    hit_waveform)

            next_i = next_record_i[record_i]
            if h['right_integration'] > n_samples_record and next_i != -1:
                is_saturated |= _build_hit_waveform(h, records[next_i],
                                                    hit_waveform)

            p['saturated_channel'][ch] |= is_saturated

            # Convert the overlapping part in place and add it to the peak
            hit_data = hit_waveform[h_start:h_end]
            hit_data *= adc_to_pe[ch]
            swv_buffer[p_start:p_end] += hit_data

            area_pe = hit_data.sum()
            area_per_channel[ch] += area_pe
            p['area'] += area_pe

        store_downsampled_waveform(p, swv_buffer)

        p['n_saturated_channels'] = p['saturated_channel'].sum()
        p['area_per_channel'][:] = area_per_channel
def sum_waveform(peaks, records, adc_to_pe, select_peaks_indices=None):
    """Compute sum waveforms for all peaks in peaks.

    Will downsample sum waveforms if they do not fit in per-peak buffer.
    Peaks are modified in place (area, area_per_channel, saturated_channel,
    n_saturated_channels and the stored waveform).

    :param peaks: Peaks for which the summed waveform is computed.
    :param records: Records whose data is summed.
    :param adc_to_pe: Per-channel factor applied to the record data.
    :arg select_peaks_indices: Indices of the peaks for partial processing,
        as an integer numpy array. If None (default), all the peaks are
        used for the summation.

    Assumes all peaks AND pulses have the same dt!
    """
    if not len(records):
        return
    if not len(peaks):
        return
    if select_peaks_indices is None:
        select_peaks_indices = np.arange(len(peaks))
    if not len(select_peaks_indices):
        return

    dt = records[0]['dt']

    # Big buffer to hold even largest sum waveforms
    # Need a little more even for downsampling..
    swv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)

    # Index of first record that could still contribute to subsequent peaks
    # Records before this do not need to be considered anymore
    left_r_i = 0

    area_per_channel = np.zeros(len(peaks[0]['area_per_channel']),
                                dtype=np.float32)

    for peak_i in select_peaks_indices:
        p = peaks[peak_i]
        p_length = p['length']

        # Clear the relevant part of the swv buffer for use
        # (we clear a bit extra for use in downsampling)
        n_clear = min(2 * p_length, len(swv_buffer))
        swv_buffer[:n_clear] = 0

        # Clear area and area per channel
        # (in case find_peaks already populated them)
        area_per_channel *= 0
        p['area'] = 0

        # Find first record that contributes to this peak
        for left_r_i in range(left_r_i, len(records)):
            candidate = records[left_r_i]
            # TODO: need test that fails if we replace < with <= here
            if p['time'] < candidate['time'] + candidate['length'] * dt:
                break
        else:
            # Records exhausted before peaks exhausted
            # TODO: this is a strange case, maybe raise warning/error?
            break

        # Scan over records that overlap
        for right_r_i in range(left_r_i, len(records)):
            r = records[right_r_i]
            ch = r['channel']
            multiplier = 2**r['amplitude_bit_shift']
            assert p['dt'] == r['dt'], "Records and peaks must have same dt"

            # Offset of the peak start relative to the record start
            shift = (p['time'] - r['time']) // dt
            n_r = r['length']

            if shift <= -p_length:
                # Record is completely to the right of the peak;
                # we've seen all overlapping records
                break

            if n_r <= shift:
                # The (real) data in this record does not actually overlap
                # with the peak
                # (although a previous, longer record did overlap)
                continue

            (r_start, r_end), (p_start, p_end) = strax.overlap_indices(
                r['time'] // dt, n_r,
                p['time'] // dt, p_length)

            # Flag the channel if the overlapping part reaches the baseline
            max_in_record = r['data'][r_start:r_end].max() * multiplier
            p['saturated_channel'][ch] |= np.int8(
                max_in_record >= r['baseline'])

            baseline_fraction = r['baseline'] % 1
            # TODO: check numba does casting correctly here!
            pe_waveform = adc_to_pe[ch] * (
                multiplier * r['data'][r_start:r_end] + baseline_fraction)

            swv_buffer[p_start:p_end] += pe_waveform

            area_pe = pe_waveform.sum()
            area_per_channel[ch] += area_pe
            p['area'] += area_pe

        store_downsampled_waveform(p, swv_buffer)

        p['n_saturated_channels'] = p['saturated_channel'].sum()
        p['area_per_channel'][:] = area_per_channel
def sum_waveform(peaks, records, adc_to_pe, n_channels=248):
    """Compute sum waveforms for all peaks in peaks.

    Will downsample sum waveforms if they do not fit in per-peak buffer.
    Peaks are modified in place (area, area_per_channel, saturated_channel,
    n_saturated_channels, data and, when downsampled, length and dt).

    :param peaks: Peaks for which the summed waveform is computed.
    :param records: Records whose data is summed.
    :param adc_to_pe: Per-channel factor applied to the record data.
    :param n_channels: Currently unused; kept for backward compatibility.
        The number of channels is taken from the length of
        ``area_per_channel`` of the peaks, and every channel contributes
        to the total area and to n_saturated_channels.

    Assumes all peaks AND pulses have the same dt!
    """
    if not len(records):
        return
    if not len(peaks):
        return
    dt = records[0]['dt']
    sum_wv_samples = len(peaks[0]['data'])

    # Big buffer to hold even largest sum waveforms
    # Need a little more even for downsampling..
    swv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)

    # Index of first record that could still contribute to subsequent peaks
    # Records before this do not need to be considered anymore
    left_r_i = 0

    # Separate name so the n_channels argument is not silently shadowed
    n_area_channels = len(peaks[0]['area_per_channel'])
    area_per_channel = np.zeros(n_area_channels, dtype=np.float32)

    for p in peaks:
        # Clear the relevant part of the swv buffer for use
        # (we clear a bit extra for use in downsampling)
        p_length = p['length']
        swv_buffer[:min(2 * p_length, len(swv_buffer))] = 0

        # Clear area and area per channel
        # (in case find_peaks already populated them)
        area_per_channel *= 0
        p['area'] = 0

        # Find first record that contributes to this peak
        for left_r_i in range(left_r_i, len(records)):
            r = records[left_r_i]
            # TODO: need test that fails if we replace < with <= here
            if p['time'] < r['time'] + r['length'] * dt:
                break
        else:
            # Records exhausted before peaks exhausted
            # TODO: this is a strange case, maybe raise warning/error?
            break

        # Scan over records that overlap
        for right_r_i in range(left_r_i, len(records)):
            r = records[right_r_i]
            ch = r['channel']
            assert p['dt'] == r['dt'], "Records and peaks must have same dt"

            shift = (p['time'] - r['time']) // dt
            n_r = r['length']
            n_p = p_length

            if shift <= -n_p:
                # Record is completely to the right of the peak;
                # we've seen all overlapping records
                break

            if n_r <= shift:
                # The (real) data in this record does not actually overlap
                # with the peak
                # (although a previous, longer record did overlap)
                continue

            (r_start, r_end), (p_start, p_end) = strax.overlap_indices(
                r['time'] // dt, n_r,
                p['time'] // dt, n_p)

            max_in_record = r['data'][r_start:r_end].max()
            # OR the flag in, so that a later unsaturated record of the
            # same channel cannot clear an earlier saturation flag.
            p['saturated_channel'][ch] |= np.int8(
                max_in_record >= r['baseline'])

            bl_fpart = r['baseline'] % 1
            # TODO: check numba does casting correctly here!
            pe_waveform = adc_to_pe[ch] * (r['data'][r_start:r_end]
                                           + bl_fpart)

            swv_buffer[p_start:p_end] += pe_waveform

            area_pe = pe_waveform.sum()
            area_per_channel[ch] += area_pe
            p['area'] += area_pe

        # Store the sum waveform
        # Do we need to downsample the swv to store it?
        downs_f = int(np.ceil(p_length / sum_wv_samples))
        if downs_f > 1:
            # Compute peak length after downsampling.
            # We floor rather than ceil here, potentially cutting off
            # some samples from the right edge of the peak.
            # If we would ceil, the peak could grow larger and
            # overlap with a subsequent next peak, crashing strax later.
            new_ns = p['length'] = int(np.floor(p_length / downs_f))
            p['data'][:new_ns] = \
                swv_buffer[:new_ns * downs_f].reshape(-1, downs_f).sum(axis=1)
            p['dt'] *= downs_f
        else:
            p['data'][:p_length] = swv_buffer[:p_length]

        # Store the saturation count and area per channel
        p['n_saturated_channels'] = p['saturated_channel'].sum()
        p['area_per_channel'][:] = area_per_channel