Example #1
import time
import traceback
import logging

import numpy as np

# AutoRecordingExtractor and the helpers _mda32_to_base64,
# _create_multiscale_recordings and _extract_data_segment are provided
# elsewhere in this package and are assumed to be importable here.

logger = logging.getLogger(__name__)


class TimeseriesView:
    def __init__(self):
        super().__init__()
        self._recording = None
        self._multiscale_recordings = None
        self._segment_size_times_num_channels = 1000000
        self._segment_size = None

    def javascript_state_changed(self, prev_state, state):
        self._set_status('running', 'Running TimeseriesView')
        self._create_efficient_access = state.get('create_efficient_access',
                                                  False)
        if not self._recording:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
            traces0 = self._recording.get_traces(
                channel_ids=self._recording.get_channel_ids(),
                start_frame=0,
                end_frame=min(self._recording.get_num_frames(), 25000))
            y_offsets = -np.mean(traces0, axis=1)
            for m in range(traces0.shape[0]):
                traces0[m, :] = traces0[m, :] + y_offsets[m]
            vv = np.percentile(np.abs(traces0), 90)
            y_scale_factor = 1 / (2 * vv) if vv > 0 else 1
            self._segment_size = int(
                np.ceil(self._segment_size_times_num_channels /
                        self._recording.get_num_channels()))
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                channel_locations = None
            self.set_state(
                dict(num_channels=self._recording.get_num_channels(),
                     channel_ids=self._recording.get_channel_ids(),
                     channel_locations=channel_locations,
                     num_timepoints=self._recording.get_num_frames(),
                     y_offsets=y_offsets,
                     y_scale_factor=y_scale_factor,
                     samplerate=self._recording.get_sampling_frequency(),
                     segment_size=self._segment_size,
                     status_message='Loaded recording.'))

        # SR = state.get('segmentsRequested', {})
        # for key in SR.keys():
        #     aa = SR[key]
        #     if not self.get_python_state(key, None):
        #         self.set_state(dict(status_message='Loading segment {}'.format(key)))
        #         data0 = self._load_data(aa['ds'], aa['ss'])
        #         data0_base64 = _mda32_to_base64(data0)
        #         state0 = {}
        #         state0[key] = dict(data=data0_base64, ds=aa['ds'], ss=aa['ss'])
        #         self.set_state(state0)
        #         self.set_state(dict(status_message='Loaded segment {}'.format(key)))
        self._set_status('finished', '')

    def on_message(self, msg):
        if msg['command'] == 'requestSegment':
            ds = msg['ds_factor']
            ss = msg['segment_num']
            data0 = self._load_data(ds, ss)
            data0_base64 = _mda32_to_base64(data0)
            self.send_message(
                dict(command='setSegment',
                     ds_factor=ds,
                     segment_num=ss,
                     data=data0_base64))

    def _load_data(self, ds, ss):
        if not self._recording:
            return
        logger.info('_load_data {} {}'.format(ds, ss))
        if ds > 1:
            if self._multiscale_recordings is None:
                self.set_state(
                    dict(status_message='Creating multiscale recordings...'))
                self._multiscale_recordings = _create_multiscale_recordings(
                    recording=self._recording,
                    progressive_ds_factor=3,
                    create_efficient_access=self._create_efficient_access)
                self.set_state(
                    dict(status_message='Done creating multiscale recording'))
            rx = self._multiscale_recordings[ds]
            # print('_extract_data_segment', ds, ss, self._segment_size)
            start_time = time.time()
            X = _extract_data_segment(recording=rx,
                                      segment_num=ss,
                                      segment_size=self._segment_size * 2)
            # print('done extracting data segment', time.time() - start_time)
            logger.info('extracted data segment {} {} {}'.format(
                ds, ss,
                time.time() - start_time))
            return X

        start_time = time.time()
        traces = self._recording.get_traces(
            start_frame=ss * self._segment_size,
            end_frame=(ss + 1) * self._segment_size)
        logger.info('extracted data segment {} {} {}'.format(
            ds, ss,
            time.time() - start_time))
        return traces

    def iterate(self):
        pass

    def _set_state(self, **kwargs):
        self.set_state(kwargs)

    def _set_error(self, error_message):
        self._set_status('error', error_message)

    def _set_status(self, status, status_message=''):
        self._set_state(status=status, status_message=status_message)
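
The helper functions referenced by Example #1 (AutoRecordingExtractor, _mda32_to_base64, _create_multiscale_recordings, _extract_data_segment) are defined elsewhere in the project. The sketch below illustrates, under stated assumptions, what two of them might look like: _extract_data_segment returning the segment_num-th block of segment_size frames from a recording extractor, and _mda32_to_base64 base64-encoding a float32 array for transport to the JavaScript front end. These are illustrative stand-ins, not the actual spikeforest implementations; in particular, the real _mda32_to_base64 presumably serializes a proper .mda buffer rather than raw float32 bytes.

import base64

import numpy as np


def _extract_data_segment(recording, segment_num, segment_size):
    # Illustrative sketch: return the traces for the segment_num-th window of
    # segment_size frames, clipped to the end of the recording.
    start_frame = segment_num * segment_size
    end_frame = min((segment_num + 1) * segment_size, recording.get_num_frames())
    return recording.get_traces(
        channel_ids=recording.get_channel_ids(),
        start_frame=start_frame,
        end_frame=end_frame)


def _mda32_to_base64(X):
    # Illustrative sketch: base64-encode the C-ordered float32 bytes of X.
    # The actual helper presumably writes a full .mda buffer (header + data).
    return base64.b64encode(np.asarray(X, dtype=np.float32).tobytes()).decode('utf-8')
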
Example #2
import numpy as np
from spikeforest2_utils import AutoRecordingExtractor
# detect_on_channel is assumed to be available from the surrounding project
# (it is not defined in this snippet); rx2 below is a recording object
# assumed to be defined upstream.

detect_threshold = 5  # as reported in the literature
detect_interval = 200
detect_interval_reference = 10
detect_sign = -1
num_events = 1000
snippet_len = (200, 200)
window_frac = 0.3
num_passes = 20
npca = 100
max_t = 30000 * 100
k = 20
ncomp = 4

R = AutoRecordingExtractor(rx2)

X = R.get_traces()  # full traces array (num_channels x num_frames)

sig = X.copy()
if detect_sign < 0:
    sig = -sig
elif detect_sign == 0:
    sig = np.abs(sig)
sig = np.max(sig, axis=0)
noise_level = np.median(np.abs(sig)) / 0.6745  # noise std estimated via the median absolute deviation (MAD)
times_reference = detect_on_channel(sig,
                                    detect_threshold=noise_level *
                                    detect_threshold,
                                    detect_interval=detect_interval_reference,
                                    detect_sign=1,
                                    margin=1000)
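

# detect_on_channel is not defined in this snippet; it is assumed to return the
# frame indices of threshold crossings of a 1-D signal, enforcing a minimum
# spacing of detect_interval samples and staying at least `margin` samples away
# from the signal edges. A minimal sketch under those assumptions (the actual
# project implementation may differ, e.g. by locating local peaks); in a single
# file it would need to be defined before the detection call above:
def detect_on_channel(sig, detect_threshold, detect_interval, detect_sign, margin=0):
    if detect_sign < 0:
        x = -sig
    elif detect_sign == 0:
        x = np.abs(sig)
    else:
        x = sig
    candidates = np.nonzero(x > detect_threshold)[0]
    times = []
    last_time = -detect_interval
    for t in candidates:
        if t < margin or t >= len(x) - margin:
            continue  # too close to the signal edges
        if t - last_time >= detect_interval:
            times.append(t)
            last_time = t
    return np.array(times, dtype=np.int64)

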
def main():
    import numpy as np
    import spikeextractors as se
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    import hither
    import kachery as ka
    # filter_recording, detect_on_channel and extract_snippets are assumed to be
    # defined elsewhere in this project; they are not part of the packages above.
    sw.init_electron()

    # bandpass filter the raw recording
    # (recobj is the raw recording object, assumed to be loaded/defined upstream)
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval
    
    detect_threshold = 3
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    num_events = 1000
    snippet_len = (200, 200)
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100
    k = 20
    ncomp = 4
    
    R = AutoRecordingExtractor(recobj2)

    X = R.get_traces()
    
    sig = X.copy()
    if detect_sign < 0:
        sig = -sig
    elif detect_sign == 0:
        sig = np.abs(sig)
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # noise std estimated via the median absolute deviation (MAD)
    times_reference = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval_reference, detect_sign=1, margin=1000)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    for j in range(snippets_reference.shape[0]):
        for m in range(snippets_reference.shape[1]):
            snippets_reference[j, m, :] = snippets_reference[j, m, :] * window0
    A_snippets_reference = snippets_reference.reshape(snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    u, s, vh = np.linalg.svd(A_snippets_reference, full_matrices=False)  # economy SVD; avoids forming huge U/Vh
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        sig = X.copy()
        if detect_sign < 0:
            sig = -sig
        elif detect_sign == 0:
            sig = np.abs(sig)
        sig = np.max(sig, axis=0)
        noise_level = np.median(np.abs(sig)) / 0.6745  # noise std estimated via the median absolute deviation (MAD)
        times = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval, detect_sign=1, margin=1000)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        for j in range(snippets.shape[0]):
            for m in range(snippets.shape[1]):
                snippets[j, m, :] = snippets[j, m, :] * window0
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        features = A_snippets @ components_reference
        
        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        A_snippets_denoised = features2 @ components_reference.T
        
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        for j in range(len(times)):
            t0 = times[j]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] = X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] + snippet_denoised_0
            X[:, t0-snippet_len[0]:t0+snippet_len[1]] = X[:, t0-snippet_len[0]:t0+snippet_len[1]] - snippet_denoised_0

    # stack the denoised signal, the residual, and the original traces as extra
    # channels so they can be compared side by side in the viewer
    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
    
    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()
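

# extract_snippets is also assumed to come from the surrounding project. A minimal
# sketch under that assumption: it returns an array of shape
# (num_events, num_channels, snippet_len[0] + snippet_len[1]) holding the traces
# around each reference frame, here as float32 so that the Gaussian windowing in
# main() stays in floating point (events too close to the edges are zero-padded;
# the actual implementation may handle them differently):
def extract_snippets(X, reference_frames, snippet_len):
    num_channels, num_frames = X.shape
    total_len = snippet_len[0] + snippet_len[1]
    snippets = np.zeros((len(reference_frames), num_channels, total_len), dtype=np.float32)
    for j, t0 in enumerate(reference_frames):
        start = int(t0) - snippet_len[0]
        end = int(t0) + snippet_len[1]
        src_start = max(start, 0)
        src_end = min(end, num_frames)
        snippets[j, :, src_start - start:src_end - start] = X[:, src_start:src_end]
    return snippets


if __name__ == '__main__':
    # Note: recobj (the raw recording object used inside main) must be defined or
    # loaded before this script can be run end to end.
    main()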