def filter_recording(recording_directory, timeseries_out):
    """Bandpass-filter a recording and write the filtered traces to an .mda file.

    NOTE(review): a second ``filter_recording`` defined later in this file
    rebinds this module-level name, so this version appears to be shadowed.

    Args:
        recording_directory: recording handle passed unchanged to
            ``AutoRecordingExtractor``.
        timeseries_out: destination path handed to ``writemda32``.

    Raises:
        Exception: when ``writemda32`` returns a falsy result.
    """
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st

    source = AutoRecordingExtractor(recording_directory)
    # Fixed passband 300-6000 with a transition width of 1000.
    filtered = st.preprocessing.bandpass_filter(
        recording=source,
        freq_min=300,
        freq_max=6000,
        freq_wid=1000,
    )
    written = writemda32(filtered.get_traces(), timeseries_out)
    if written:
        return
    raise Exception('Unable to write output file.')
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    """Bandpass-filter a recording object and return a copy pointing at the filtered data.

    The filtered traces are written to a temporary ``raw.mda`` file, which is
    stored with ``ka.store_file``. The returned object is ``recobj.copy()``
    with its ``'raw'`` entry replaced by the value ``ka.store_file`` returns.

    Args:
        recobj: recording object accepted by ``AutoRecordingExtractor``; it
            must support ``.copy()`` and item assignment.
        freq_min: lower passband edge passed to ``bandpass_filter``.
        freq_max: upper passband edge passed to ``bandpass_filter``.
        freq_wid: transition width passed to ``bandpass_filter``.

    Returns:
        The copied recording object with the updated ``'raw'`` entry.

    Raises:
        Exception: when ``writemda32`` returns a falsy result.
    """
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st

    source = AutoRecordingExtractor(recobj)
    filtered = st.preprocessing.bandpass_filter(
        recording=source,
        freq_min=freq_min,
        freq_max=freq_max,
        freq_wid=freq_wid,
    )
    result = recobj.copy()
    with hither.TemporaryDirectory() as workdir:
        mda_path = workdir + '/raw.mda'
        traces = filtered.get_traces()
        if not writemda32(traces, mda_path):
            raise Exception('Unable to write output file.')
        # Store the file before leaving the temporary-directory context.
        result['raw'] = ka.store_file(mda_path)
    return result
def _mda32_to_base64(X):
    """Serialize ``X`` with ``writemda32`` into memory and return it base64-encoded.

    Args:
        X: array passed straight to ``writemda32``.

    Returns:
        The serialized bytes, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        RuntimeError: if ``writemda32`` explicitly returns ``False``.
    """
    f = io.BytesIO()
    # The other writers in this file treat a falsy writemda32 result as a
    # failed write. Previously the result was ignored here, so a failure would
    # silently encode whatever partial data reached the buffer. Only an
    # explicit False is treated as failure, so an implementation that returns
    # None on success keeps working.
    if writemda32(X, f) is False:
        raise RuntimeError('Unable to write mda32 data.')
    return base64.b64encode(f.getvalue()).decode('utf-8')
def _detect_event_times(X, *, detect_sign, detect_threshold, detect_interval, max_t):
    """Detect event frames on the per-frame channel maximum of the sign-adjusted traces.

    Args:
        X: traces array of shape (channels, frames). ``main`` allocates its
            ``X_signal`` with that shape and slices ``X`` along axis 1 by frame.
            X is not modified.
        detect_sign: <0 negates X, 0 uses |X|, >0 uses X as-is.
        detect_threshold: threshold expressed in units of the robust noise level.
        detect_interval: passed through to ``detect_on_channel``.
        max_t: detected frames greater than this are discarded.

    Returns:
        The frames returned by ``detect_on_channel`` that are <= ``max_t``.
    """
    if detect_sign < 0:
        sig = -X
    elif detect_sign == 0:
        sig = np.abs(X)
    else:
        sig = X
    # Collapse the channel axis: one detection value per frame.
    sig = np.max(sig, axis=0)
    # Robust noise scale: median(|sig|) / 0.6745.
    noise_level = np.median(np.abs(sig)) / 0.6745
    times = detect_on_channel(
        sig,
        detect_threshold=noise_level * detect_threshold,
        detect_interval=detect_interval,
        detect_sign=1,  # sig has already been sign-adjusted above
        margin=1000,
    )
    return times[times <= max_t]


def _window_and_flatten(snippets, window):
    """Taper every snippet by ``window`` along its last axis (in place) and flatten.

    Returns ``snippets`` reshaped to (num_events, channels * samples).
    """
    # Broadcasting over (event, channel) replaces the old per-(event, channel)
    # loop. Assigning back into ``snippets`` keeps its dtype, exactly as the
    # loop's slice assignment did.
    snippets[...] = snippets * window
    return snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])


def main():
    """Run an interactive denoising experiment and display the result.

    Steps:
      1. Bandpass-filter ``recobj`` via ``filter_recording.run``.
      2. Detect reference events, taper and flatten their snippets, and build
         a PCA basis (leading right singular vectors, at most ``npca``) plus a
         nearest-neighbour index on the reference features.
      3. For up to ``num_passes`` passes:
         - detect events on the residual ``X``;
         - project them onto the PCA basis;
         - reconstruct each event by PLS regression on its reference
           neighbours;
         - add the reconstruction to ``X_signal`` and subtract it from ``X``.
      4. Write [signal; residual; filtered traces] stacked on the channel axis
         to a temporary .mda file, store it, and open a TimeseriesView on it.

    NOTE(review): ``recobj`` is not a parameter; it is read from an enclosing
    (presumably module-level) scope. Confirm it is defined before calling.

    Raises:
        Exception: when ``writemda32`` fails to write the output file.
    """
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw

    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000,
        ).retval

    detect_threshold = 3  # in units of the robust noise level
    detect_interval = 200  # detection interval for the subtraction passes
    detect_interval_reference = 10  # denser detection for the reference set
    detect_sign = -1
    snippet_len = (200, 200)
    window_frac = 0.3  # Gaussian taper sigma on the [-1, 1] snippet axis
    num_passes = 20
    npca = 100
    max_t = 30000 * 100  # events after this frame are ignored
    k = 20  # neighbours used per event (k + 1 are queried, the first dropped)
    ncomp = 4  # PLS components

    R = AutoRecordingExtractor(recobj2)
    X = R.get_traces()

    times_reference = _detect_event_times(
        X,
        detect_sign=detect_sign,
        detect_threshold=detect_threshold,
        detect_interval=detect_interval_reference,
        max_t=max_t,
    )
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2 / (2 * window_frac**2))
    A_snippets_reference = _window_and_flatten(snippets_reference, window0)

    print('PCA...')
    # Only the right singular vectors are used. full_matrices=False avoids
    # materialising the (num_events x num_events) U matrix. With fewer
    # reference events than npca, fewer than npca components are kept.
    _, _, vh = np.linalg.svd(A_snippets_reference, full_matrices=False)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)
    n_before, n_after = snippet_len

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        times = _detect_event_times(
            X,
            detect_sign=detect_sign,
            detect_threshold=detect_threshold,
            detect_interval=detect_interval,
            max_t=max_t,
        )
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break

        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        A_snippets = _window_and_flatten(snippets, window0)
        features = A_snippets @ components_reference

        print('Finding nearest neighbors...')
        _, indices = nbrs.kneighbors(features)

        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        num_events = features.shape[0]
        for j in range(num_events):
            print(f'{j+1} of {num_events}')
            # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            inds0 = np.squeeze(indices[j, :])[1:]
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            # The feature dimensions act as samples: the event's feature
            # vector is regressed on its neighbours' feature vectors, and the
            # fitted values become the denoised features.
            pls.fit(f_neighbors.T, features[j, :])
            features2[j, :] = np.ravel(pls.predict(f_neighbors.T))

        A_snippets_denoised = features2 @ components_reference.T
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)
        for t0, snippet_denoised_0 in zip(times, snippets_denoised):
            frames = slice(t0 - n_before, t0 + n_after)
            # Explicit add/subtract + assignment keeps the original
            # cast-on-assign semantics for X's dtype.
            X_signal[:, frames] = X_signal[:, frames] + snippet_denoised_0
            X[:, frames] = X[:, frames] - snippet_denoised_0

    # Stack [signal; residual; filtered traces] on the channel axis.
    # NOTE(review): assumes R.get_traces() returns a fresh array. Otherwise the
    # in-place edits to X above would also appear in the third block; confirm.
    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)
    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        if not writemda32(S, raw_fname):
            raise Exception('Unable to write output file.')
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()