# Module-level imports assumed by the spike-processing functions below
# (aliases inferred from usage; h5s/h5 are taken to be hdf5storage/h5py):
import logging
import os
from math import ceil

import h5py as h5
import hdf5storage as h5s
import numpy as np
from scipy import signal
from tqdm import tqdm

logger = logging.getLogger(__name__)


def estimate_noise(arr, lc=300, hc=6000, num_channels=4, fs=3e4,
                   microvolt_factor=0.195, ne_bin_s=1):
    """Calculate MAD (median absolute deviation) of the bandpass filtered array.

    Returns an (n_bins, num_channels) array of per-bin noise estimates in uV.
    """
    ne_bin_size = int(ne_bin_s * fs)  # noise estimation bin size

    # Filter
    b, a = butter_bandpass(lc, hc, fs)

    batches = get_batches(arr.shape[0], ne_bin_size)
    ne = np.zeros((len(batches), num_channels))
    nfac = 1 / 0.6745  # scales the MAD to an estimate of the noise standard deviation

    # Calculate MAD (median absolute deviation) over chunks
    for n, batch in enumerate(tqdm(batches, leave=False, desc='1) estimating')):
        batch_size = min(ne_bin_size, arr.shape[0] - batch)
        filtered = signal.filtfilt(b, a, arr[batch:batch + batch_size, :].astype(np.double),
                                   axis=0) * microvolt_factor
        for ch in range(num_channels):
            ne[n, ch] = np.median(abs(filtered[:, ch]) * nfac)
    return ne
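# `butter_bandpass` and `get_batches` are used above but not defined in this
# section. Minimal sketches of what they plausibly look like, assuming
# `butter_bandpass` wraps scipy.signal.butter and `get_batches` yields chunk
# start offsets; the actual implementations elsewhere in the repo may differ.
def butter_bandpass(lowcut, highcut, fs, order=3):
    """Sketch: (b, a) bandpass filter coefficients via scipy.signal.butter."""
    nyq = 0.5 * fs  # normalize cutoffs to the Nyquist frequency
    return signal.butter(order, [lowcut / nyq, highcut / nyq], btype='band')


def get_batches(n_samples, batch_size):
    """Sketch: start offsets of consecutive batches covering n_samples."""
    return list(range(0, n_samples, batch_size))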
def extract_waveforms(timestamps, arr, outpath, s_pre=10, s_post=22, lc=300, hc=6000,
                      chunk_size_s=60, chunk_overlap_s=0.05, fs=3e4):
    """Extract waveforms from the raw signal, from s_pre samples before to s_post
    samples after each spike trough. Waveforms and timestamps are stored directly
    in .mat files.
    """
    assert max(timestamps) + s_post < arr.shape[0]
    assert min(timestamps) - s_pre >= 0
    if s_pre + s_post != 32:
        logger.warning(f'Number of waveform samples {s_pre}+{s_post} != 32 as expected by MClust!')

    chunk_size = int(chunk_size_s * fs)
    n_samples = s_pre + s_post
    n_channels = arr.shape[1]

    b, a = butter_bandpass(lc, hc, fs)

    # samples to cut around detection (threshold crossing)
    bc_samples = np.arange(-s_pre, s_post).reshape((1, n_samples))

    end = arr.shape[0]
    chunk_starts = [cs * chunk_size for cs in range(ceil(end / chunk_size))]
    chunk_overlap = int(chunk_overlap_s * fs)

    # prepare the mat file
    if os.path.exists(outpath):
        raise FileExistsError('Mat file already exists. Exiting.')

    # TODO: Save additional metadata alongside waveforms, e.g. thresholds, version, original paths
    h5s.savemat(str(outpath),
                {'n': len(timestamps),
                 'index': np.double((timestamps - s_pre) / 3),  # convert to MClust time domain
                 'readme': 'Written by dataman.',
                 # 'original_path': str(path)
                 }, compress=False)

    n_samples_concat = n_samples * n_channels
    with h5.File(str(outpath), 'a') as hf:
        hf.create_dataset('spikes', (n_samples_concat, len(timestamps)),
                          maxshape=(n_samples_concat, None), dtype='int16')

        for n_chunk, start in enumerate(tqdm(chunk_starts, leave=False, desc='3) extracting')):
            # limits of core batch chunk
            b_start = start
            b_end = min(start + chunk_size, end)

            # limits of chunk with flanking overlaps
            o_start = max(0, b_start - chunk_overlap)
            o_end = min(b_end + chunk_overlap, end)

            # Relevant timestamps
            ts_idc = np.where((timestamps >= b_start) & (timestamps < b_end))[0]
            peaks = timestamps[ts_idc].reshape(-1, 1) - o_start
            if not len(peaks):
                continue

            # Bandpass filter raw signal
            filtered = signal.filtfilt(b, a, arr[o_start:o_end], axis=0)

            # Extract waveforms
            idc = bc_samples + peaks
            try:
                hf['spikes'][:, min(ts_idc):max(ts_idc) + 1] = filtered[idc].reshape(-1, n_samples_concat).T
            except IndexError:
                logger.error('Spikes out of bounds!')
                break

        waveforms = np.array(hf['spikes'], dtype='int16').reshape([s_pre + s_post, n_channels, -1])
    return waveforms
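# Tiny demo of the index-broadcasting trick used in extract_waveforms
# (illustrative only, not called anywhere): a (1, n_samples) sample-offset row
# plus an (n_spikes, 1) peak column broadcasts to an (n_spikes, n_samples)
# index matrix, so a single fancy-indexing call cuts all waveforms at once.
def _broadcast_indices_demo():
    offsets = np.arange(-2, 3).reshape(1, -1)  # like bc_samples, s_pre=2, s_post=3
    peaks = np.array([10, 50]).reshape(-1, 1)  # two detected troughs
    idc = offsets + peaks                      # shape (2, 5)
    assert idc.shape == (2, 5)
    return idc  # [[8, 9, 10, 11, 12], [48, 49, 50, 51, 52]]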
def detect_spikes(arr, min_thresholds, max_sd=18, fs=3e4, chunk_size_s=60, chunk_overlap_s=0.05,
                  lc=300, hc=6000, s_pre=10, s_post=22, reject_overlap=16, align='min'):
    """Given the wideband signal, find peaks (minima) in the bandpass filtered signal.

    Returns an array of curated timestamps with duplicate and overlapping spikes rejected.
    """
    # TODO: Interpolation
    # TODO: Maximum artifact rejection
    # TODO: Return rejected timestamps
    chunk_size = int(chunk_size_s * fs)  # chunk size for detection

    microvolt_factor = 0.195
    use_thr = min_thresholds / microvolt_factor

    timestamps = []
    crs = []  # threshold crossing starts, kept for debugging

    max_thr = use_thr * max_sd
    if max_sd is not None:
        # max_thr is computed but not applied yet; warn whenever a limit was requested
        logger.warning('Maximum rejection for spike detection not implemented.')

    # waveform_chunks = []
    # rejections = 0

    b, a = butter_bandpass(lc, hc, fs)

    # samples to cut around detection (threshold crossing)
    bc_samples = np.arange(-s_pre, s_post).reshape([1, -1])

    # Chunks will have partial overlap.
    # 0
    # |o|--- chunk 1 ---|x|
    #    |o|--- chunk 2 ---|x|
    #       |o|--- chunk 3 ---|
    #                       end
    # Spikes with peak in the |x| region will be ignored
    chunk_overlap = int(chunk_overlap_s * fs)  # 50 ms chunk boundary overlap

    # Gather chunk start and ends
    end = arr.shape[0]
    chunk_starts = [cs * chunk_size for cs in range(ceil(end / chunk_size))]

    for n_chunk, start in enumerate(tqdm(chunk_starts, leave=False, desc='2) detecting')):
        # limits of core batch chunk
        b_start = start
        b_end = min(start + chunk_size, end)

        # limits of chunk with flanking overlaps
        o_start = max(0, b_start - chunk_overlap)
        o_end = min(b_end + chunk_overlap, end)

        # Bandpass filter raw signal
        filtered = signal.filtfilt(b, a, arr[o_start:o_end], axis=0)

        # Merge threshold crossings
        # TODO: Only merge valid channels!
        crossings = np.clip(np.sum(filtered < -use_thr, axis=1), 0, 1).astype(np.int8)
        crossings = signal.medfilt(crossings, 3)

        # exclude crossings with timestamps too close to boundaries
        xr_starts = (np.diff(crossings) > 0).nonzero()[0]
        xr_starts = xr_starts[(xr_starts > s_pre) & (xr_starts < filtered.shape[0] - s_post)]

        # Warn if no spikes were found. That's suspicious given how we calculate the threshold.
        n_spikes = xr_starts.shape[0]
        if not n_spikes:
            logger.warning(f'No spikes in chunk {n_chunk} @ [{b_start} to {b_end}]')
            continue

        # Extract preliminary waveforms
        # get waveform indices by broadcasting indices of crossings
        bc_spikes = xr_starts.reshape([n_spikes, 1])
        idc = bc_samples + bc_spikes
        try:
            wv = filtered[idc]
        except IndexError:
            logger.error(f'Spikes out of bounds in chunk {n_chunk} @ [{b_start} to {b_end}]')
            break

        # Alignment to the spike trough
        alignments = np.array([wv_alignment(wv[wv_idx, :], method=align)[0]
                               for wv_idx in range(wv.shape[0])])
        wv_starts = xr_starts + alignments + o_start - s_pre

        # first chunk, no overlap
        lim_a = 0 if n_chunk == 0 else b_start
        not_early = wv_starts >= lim_a

        # last chunk, no overlap
        lim_b = end if n_chunk == len(chunk_starts) - 1 else b_end
        # the full waveform (s_pre + s_post samples) must end before lim_b
        not_late = wv_starts < lim_b - (s_pre + s_post)

        timestamps.append(wv_starts[not_early & not_late])
        crs.append(xr_starts[not_early & not_late] + o_start)

    timestamps = np.sort(np.concatenate(timestamps))

    ts_diff = np.diff(timestamps) > reject_overlap
    # np.diff has one element fewer than timestamps; the first spike is always kept
    too_close = timestamps.shape[0] - 1 - np.sum(ts_diff)
    logger.warning('{} ({:.1f}%) spikes rejected due to >{} sample overlap'.format(
        too_close, too_close / timestamps.shape[0] * 100, reject_overlap))

    valid_timestamps = timestamps[1:][ts_diff]
    valid_timestamps = np.insert(valid_timestamps, 0, timestamps[0])
    return valid_timestamps
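# `wv_alignment` is called above but not defined in this section. A minimal
# sketch consistent with that call site, assuming method='min' returns the
# sample offset of the deepest trough across channels and its value; the real
# implementation may differ (e.g. interpolation or per-channel weighting).
def wv_alignment(wv, method='min'):
    """Sketch: (offset, value) of the alignment point in a (samples, channels) waveform."""
    if method != 'min':
        raise NotImplementedError(method)
    merged = wv.min(axis=1)          # deepest value per sample across channels
    offset = int(np.argmin(merged))  # sample index of the global trough
    return offset, merged[offset]


# End-to-end usage sketch (the file name, channel count, and the 4.5x threshold
# multiplier are assumptions for illustration, not values taken from this code):
# raw = np.memmap('raw.dat', dtype='int16', mode='r').reshape(-1, 4)
# noise_uv = estimate_noise(raw).mean(axis=0)              # per-channel noise in uV
# ts = detect_spikes(raw, min_thresholds=noise_uv * 4.5)   # thresholds in uV
# wv = extract_waveforms(ts, raw, outpath='waveforms.mat')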
def __init__(self, target_path, n_cols=1, channels=None, start=0, dtype='int16', *args, **kwargs):
    app.Canvas.__init__(self, title=target_path, keys='interactive',
                        size=(1900, 1000), position=(0, 0), app='pyqt5')
    self.logger = logging.getLogger(__name__)

    # Target configuration (format, sampling rate, sizes...)
    self.target_path = target_path
    self.logger.debug('Target path: {}'.format(target_path))
    self.format = util.detect_format(self.target_path)
    self.logger.debug('Target module: {}'.format(self.format))
    assert self.format is not None

    self.input_dtype = dtype
    self.metadata = self._get_target_config(*args, **kwargs)

    # TODO: Have .dat format join in on the new format fun...
    if 'HEADER' in self.metadata:
        self.logger.debug('Using legacy .dat file metadata dictionary layout')
        self.fs = self.metadata['HEADER']['sampling_rate']
        self.input_dtype = self.metadata['DTYPE']
        self.n_samples_total = int(self.metadata['HEADER']['n_samples'])
        self.n_channels = self.metadata['CHANNELS']['n_channels']
        self.block_size = self.metadata['HEADER']['block_size']
    elif 'SUBSETS' in self.metadata:
        # FIXME: Only traverses first subset.
        self.logger.debug('Using new style metadata dictionary layout')
        # get number of channels and sampling rate from the first subset
        first_subset = next(iter(self.metadata['SUBSETS'].values()))
        self.fs = first_subset['JOINT_HEADERS']['sampling_rate']
        self.n_samples_total = int(first_subset['JOINT_HEADERS']['n_samples'])
        self.n_channels = len(first_subset['FILES'])
        self.block_size = first_subset['JOINT_HEADERS']['block_size']
    else:
        raise ValueError('Unknown metadata format from target.')

    self.logger.debug('From target: {:.2f} Hz, {} channels, {} samples, dtype={}'.format(
        self.fs, self.n_channels, self.n_samples_total, self.input_dtype))

    self.channel_order = channels  # if None: no particular order

    # 300-6000 Hz bandpass filter
    self.filter = util.butter_bandpass(300, 6000, self.fs)
    self.apply_filter = False

    self.duration_total = util.fmt_time(self.n_samples_total / self.fs)

    # Buffer to store all the pre-loaded signals
    self.buf = SharedBuffer.SharedBuffer()
    self.buffer_length = BUFFER_LENGTH
    self.buf.initialize(n_channels=self.n_channels, n_samples=self.buffer_length,
                        np_dtype=BUFFER_DTYPE)

    # Streamer to keep the buffer filled
    self.streamer = None
    self.stream_queue = Queue()
    self.start_streaming()

    # Set up viewport and viewing state variables
    # running traces: looks cool, but useless for the most part
    self.running = False
    self.dirty = True
    self.offset = int(start * self.fs / 1024)
    self.drag_offset = 0
    self.n_cols = int(n_cols)
    self.n_rows = int(math.ceil(self.n_channels / self.n_cols))
    self.logger.debug('col/row: {}, buffer_length: {}'.format(
        (self.n_cols, self.n_rows), self.buffer_length))

    # Most of the magic happens in the vertex shader, moving the samples into
    # "position" using an affine transform based on the number of columns and
    # rows of the plot, scaling, etc.
    self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
    self._feed_shaders()

    gloo.set_viewport(0, 0, *self.physical_size)
    self._timer = app.Timer('auto', connect=self.on_timer, start=True)
    gloo.set_state(clear_color='black', blend=True,
                   blend_func=('src_alpha', 'one_minus_src_alpha'))
    self.show()
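# Hedged usage sketch: `Vis` is a hypothetical name for the app.Canvas
# subclass this __init__ belongs to, and the target path is made up.
# Constructing the canvas starts the streamer and opens the window; control
# then passes to the vispy/pyqt5 event loop.
# if __name__ == '__main__':
#     canvas = Vis('experiment/raw.dat', n_cols=2, start=120)
#     app.run()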