def filter_recording(recording_directory, timeseries_out):
    """Bandpass-filter a recording and write the filtered traces to a file.

    The recording is opened with ``AutoRecordingExtractor``, passed through
    ``spiketoolkit``'s ``bandpass_filter`` with fixed settings
    (``freq_min=300``, ``freq_max=6000``, ``freq_wid=1000`` -- presumably Hz,
    confirm against spiketoolkit), and the resulting traces are written with
    ``writemda32``.

    Args:
        recording_directory: Handed unchanged to ``AutoRecordingExtractor``.
        timeseries_out: Destination handed unchanged to ``writemda32``.

    Raises:
        Exception: If ``writemda32`` returns a falsy value.
    """
    from spikeforest2_utils import AutoRecordingExtractor, writemda32
    import spiketoolkit as st

    band = dict(freq_min=300, freq_max=6000, freq_wid=1000)
    source = AutoRecordingExtractor(recording_directory)
    filtered = st.preprocessing.bandpass_filter(recording=source, **band)

    written = writemda32(filtered.get_traces(), timeseries_out)
    if not written:
        raise Exception('Unable to write output file.')
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    """Return a copy of ``recobj`` whose ``'raw'`` entry is the filtered data.

    The recording described by ``recobj`` is bandpass-filtered with
    ``spiketoolkit``, the filtered traces are written to ``raw.mda`` inside a
    hither temporary directory, and that file is stored via ``ka.store_file``.
    ``recobj`` itself is not modified; the stored-file reference is placed on
    a copy.

    Args:
        recobj: Recording object accepted by ``AutoRecordingExtractor``; must
            provide ``.copy()`` and item assignment.
        freq_min: Forwarded to ``bandpass_filter`` as ``freq_min``.
        freq_max: Forwarded to ``bandpass_filter`` as ``freq_max``.
        freq_wid: Forwarded to ``bandpass_filter`` as ``freq_wid``.

    Returns:
        The copy of ``recobj`` with ``'raw'`` replaced by the value returned
        from ``ka.store_file``.

    Raises:
        Exception: If ``writemda32`` returns a falsy value.
    """
    from spikeforest2_utils import AutoRecordingExtractor, writemda32
    import spiketoolkit as st

    filtered = st.preprocessing.bandpass_filter(
        recording=AutoRecordingExtractor(recobj),
        freq_min=freq_min,
        freq_max=freq_max,
        freq_wid=freq_wid,
    )
    result = recobj.copy()

    # The .mda file only has to exist until ka.store_file has consumed it.
    with hither.TemporaryDirectory() as workdir:
        mda_path = workdir + '/raw.mda'
        written = writemda32(filtered.get_traces(), mda_path)
        if not written:
            raise Exception('Unable to write output file.')
        result['raw'] = ka.store_file(mda_path)
    return result
示例#3
0
def _mda32_to_base64(X):
    """Serialize ``X`` with ``writemda32`` in memory and return base64 text.

    Args:
        X: Data handed unchanged to ``writemda32``.

    Returns:
        str: The bytes that ``writemda32`` wrote to an in-memory buffer,
        base64-encoded and decoded as UTF-8.
    """
    buffer = io.BytesIO()
    # The return value of writemda32 is intentionally not inspected here.
    writemda32(X, buffer)
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')
def main():
    """Run an iterative nearest-neighbor / PLS denoising demo and display it.

    Steps, in order:

    1. Bandpass-filter ``recobj`` through ``filter_recording.run`` under a
       hither container/cache configuration.
    2. Detect "reference" events on the filtered traces, cut tapered snippets
       around them and take the first ``npca`` right-singular vectors of the
       flattened snippets as projection components.
    3. For up to ``num_passes`` passes: detect events on the current residual,
       project their snippets onto the reference components, replace each
       event's feature vector by a PLS prediction built from its nearest
       reference neighbors, then add the reconstructed snippet to the
       accumulated signal and subtract it from the residual.
    4. Stack accumulated signal, residual and ``get_traces()`` of the filtered
       recording along axis 0 (channels), store the result and open it in a
       ``TimeseriesView``.

    NOTE(review): ``recobj``, ``np``, ``hither``, ``ka``,
    ``detect_on_channel`` and ``extract_snippets`` are not defined inside this
    function; presumably they are module-level names -- confirm, in
    particular for ``recobj``.
    """
    import spikeextractors as se
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # Bandpass filter, run under the hither container/cache configuration.
    with hither.config(container='default', cache='default_readwrite'):
        filter_job = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000,
        )
        recobj2 = filter_job.retval

    # Tunable parameters of the demo.
    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    num_events = 1000  # not referenced anywhere below
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100
    k = 20
    ncomp = 4

    pre_samples, post_samples = snippet_len

    def find_events(traces, interval):
        """Detect events on the per-frame maximum over channels of ``traces``.

        The detection threshold is ``detect_threshold`` times
        median(|envelope|) / 0.6745. Only times ``<= max_t`` are returned.
        """
        if detect_sign < 0:
            rectified = -traces
        elif detect_sign == 0:
            rectified = np.abs(traces)
        else:
            rectified = traces.copy()
        envelope = np.max(rectified, axis=0)
        noise_level = np.median(np.abs(envelope)) / 0.6745
        found = detect_on_channel(
            envelope,
            detect_threshold=noise_level * detect_threshold,
            detect_interval=interval,
            detect_sign=1,
            margin=1000,
        )
        return found[found <= max_t]

    def taper_and_flatten(snips, taper):
        """Scale ``snips`` in place along its last axis and reshape to 2-D.

        ``snips`` keeps its 3-D shape; the returned array has one row per
        event with the remaining two axes merged.
        """
        # Slice assignment keeps the array's own dtype, like the row-wise
        # assignments it replaces.
        snips[...] = snips * taper
        n_events, n_channels, n_samples = snips.shape
        return snips.reshape(n_events, n_channels * n_samples)

    extractor = AutoRecordingExtractor(recobj2)

    # ``residual`` is modified in place by every pass further down.
    residual = extractor.get_traces()

    times_reference = find_events(residual, detect_interval_reference)
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(
        residual, reference_frames=times_reference, snippet_len=snippet_len)
    # Gaussian taper over the snippet's sample axis, width set by window_frac.
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    flat_reference = taper_and_flatten(snippets_reference, window0)

    print('PCA...')
    _, _, vh = np.linalg.svd(flat_reference)
    components_reference = vh[:npca, :].T
    features_reference = flat_reference @ components_reference

    print('Setting up nearest neighbors...')
    neighbor_model = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree')
    nbrs = neighbor_model.fit(features_reference)

    signal_shape = (extractor.get_num_channels(), extractor.get_num_frames())
    accumulated = np.zeros(signal_shape, dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        times = find_events(residual, detect_interval)
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break

        snippets = extract_snippets(
            residual, reference_frames=times, snippet_len=snippet_len)
        features = taper_and_flatten(snippets, window0) @ components_reference

        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)

        print('PLS regression...')
        n_rows = features.shape[0]
        for row in range(n_rows):
            print(f'{row+1} of {n_rows}')
            # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            neighbor_ids = np.squeeze(indices[row, :])[1:]
            f_neighbors = features_reference[neighbor_ids, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[row, :].T)
            features2[row, :] = pls.predict(f_neighbors.T).T

        flat_denoised = features2 @ components_reference.T
        snippets_denoised = flat_denoised.reshape(snippets.shape)

        # Move each reconstructed snippet from the residual to the signal.
        for idx, t0 in enumerate(times):
            patch = np.squeeze(snippets_denoised[idx, :, :])
            span = slice(t0 - pre_samples, t0 + post_samples)
            accumulated[:, span] = accumulated[:, span] + patch
            residual[:, span] = residual[:, span] - patch

    # NOTE(review): ``residual`` was modified in place; whether the fresh
    # get_traces() call returns untouched data depends on the extractor --
    # confirm.
    stacked = np.concatenate(
        (accumulated, residual, extractor.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(stacked, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)

    viewer = sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj))
    viewer.show()