class Recording:
    """Widget backend that loads a recording and publishes its metadata.

    On the first state change that carries a ``recording`` entry, the
    recording is opened with ``AutoRecordingExtractor``. Its channel count,
    channel ids, channel locations, frame count and sampling frequency are
    then pushed to the front end through ``self.set_state``.

    NOTE(review): ``set_state`` is not defined in this class. It is presumably
    provided by the widget framework this class is combined with — confirm.
    """

    def __init__(self):
        super().__init__()
        # Extractor instance; stays None until a recording loads successfully.
        self._recording = None

    def javascript_state_changed(self, prev_state, state):
        """Handle a state update coming from the JavaScript side.

        Args:
            prev_state: Previous state dict (unused).
            state: Current state dict. ``state['recording']`` is the recording
                specification handed to ``AutoRecordingExtractor``.

        Progress is reported through the ``status``/``status_message`` state
        keys:

        - ``'running'`` while loading.
        - ``'error'`` when the recording is missing or cannot be opened. The
          method returns early in that case.
        - ``'finished'`` once the recording is (or already was) loaded.
        """
        self._set_status('running', 'Running Recording')

        # Explicit None check: the extractor object's truthiness is not a
        # reliable "loaded" flag.
        if self._recording is None:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                # Channel locations are optional; some extractors may not
                # provide them.
                channel_locations = None
            self.set_state(dict(
                num_channels=self._recording.get_num_channels(),
                channel_ids=self._recording.get_channel_ids(),
                channel_locations=channel_locations,
                num_timepoints=self._recording.get_num_frames(),
                samplerate=self._recording.get_sampling_frequency(),
                status_message='Loaded recording.',
            ))

        self._set_status('finished', '')

    def _set_state(self, **kwargs):
        """Forward keyword arguments to ``set_state`` as a single dict."""
        self.set_state(kwargs)

    def _set_error(self, error_message):
        """Report an error status with the given message."""
        self._set_status('error', error_message)

    def _set_status(self, status, status_message=''):
        """Publish the ``status`` and ``status_message`` state keys."""
        self._set_state(status=status, status_message=status_message)
class TimeseriesView:
    """Widget backend that serves a recording's traces in fixed-size segments.

    On the first state change carrying a ``recording`` entry, the recording is
    opened and the following are published through ``self.set_state``:

    - Metadata, plus display offsets/scale computed from its first frames.
    - The number of frames per segment.

    Segments are later delivered on demand in response to ``requestSegment``
    messages.

    NOTE(review): ``set_state`` and ``send_message`` are not defined in this
    class. They are presumably provided by the widget framework — confirm.
    """

    def __init__(self):
        super().__init__()
        # Extractor instance; stays None until a recording loads successfully.
        self._recording = None
        # Built lazily on the first request with ds_factor > 1, then indexed
        # by the downsampling factor.
        self._multiscale_recordings = None
        # Target number of samples (frames x channels) per segment.
        self._segment_size_times_num_channels = 1000000
        # Frames per segment; set once the recording is loaded.
        self._segment_size = None
        # Overwritten from the JS state on every state change. It is
        # initialized here so _load_data cannot hit an AttributeError when a
        # message arrives first.
        self._create_efficient_access = False

    def javascript_state_changed(self, prev_state, state):
        """Handle a state update coming from the JavaScript side.

        Args:
            prev_state: Previous state dict (unused).
            state: Current state dict. Keys read:
                - ``'recording'``: specification for ``AutoRecordingExtractor``.
                - ``'create_efficient_access'``: optional flag forwarded to
                  ``_create_multiscale_recordings``.

        Status is reported via ``status``/``status_message``:

        - ``'running'`` while loading.
        - ``'error'`` on failure, with an early return.
        - ``'finished'`` when done.
        """
        self._set_status('running', 'Running TimeseriesView')
        self._create_efficient_access = state.get('create_efficient_access', False)

        if self._recording is None:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
            recording = self._recording
            num_channels = recording.get_num_channels()
            num_frames = recording.get_num_frames()
            # Display offsets/scale are estimated from at most the first
            # 25000 frames.
            traces0 = recording.get_traces(
                channel_ids=recording.get_channel_ids(),
                start_frame=0,
                end_frame=min(num_frames, 25000),
            )
            y_offsets, y_scale_factor = self._compute_display_scaling(traces0)
            # Frames per segment, chosen so that frames x channels is roughly
            # _segment_size_times_num_channels.
            self._segment_size = int(
                np.ceil(self._segment_size_times_num_channels / num_channels))
            try:
                channel_locations = recording.get_channel_locations()
            except Exception:
                # Channel locations are optional; some extractors may not
                # provide them.
                channel_locations = None
            self.set_state(dict(
                num_channels=num_channels,
                channel_ids=recording.get_channel_ids(),
                channel_locations=channel_locations,
                num_timepoints=num_frames,
                y_offsets=y_offsets,
                y_scale_factor=y_scale_factor,
                samplerate=recording.get_sampling_frequency(),
                segment_size=self._segment_size,
                status_message='Loaded recording.',
            ))

        self._set_status('finished', '')

    @staticmethod
    def _compute_display_scaling(traces):
        """Return ``(y_offsets, y_scale_factor)`` for displaying ``traces``.

        Args:
            traces: 2-D array indexed ``[channel, frame]``.

        Returns:
            A tuple of:

            - ``y_offsets``: minus the per-channel mean.
            - ``y_scale_factor``: maps the 90th percentile of the absolute
              centered traces to 0.5. Falls back to 1 when that percentile
              is not positive.
        """
        y_offsets = -np.mean(traces, axis=1)
        # Out-of-place broadcast add. This never writes into the extractor's
        # array and does not truncate the float offsets when traces has an
        # integer dtype.
        centered = traces + y_offsets[:, np.newaxis]
        vv = np.percentile(np.abs(centered), 90)
        y_scale_factor = 1 / (2 * vv) if vv > 0 else 1
        return y_offsets, y_scale_factor

    def on_message(self, msg):
        """Answer a ``requestSegment`` message with a ``setSegment`` reply.

        Args:
            msg: Message dict with keys:
                - ``'command'``
                - ``'ds_factor'``: downsampling factor.
                - ``'segment_num'``: segment index.

        Other commands are ignored. Nothing is sent if no recording has been
        loaded yet.
        """
        if msg['command'] != 'requestSegment':
            return
        ds = msg['ds_factor']
        ss = msg['segment_num']
        data0 = self._load_data(ds, ss)
        if data0 is None:
            # No recording loaded yet; encoding None would fail.
            return
        data0_base64 = _mda32_to_base64(data0)
        self.send_message(
            dict(command='setSegment', ds_factor=ds, segment_num=ss, data=data0_base64))

    def _load_data(self, ds, ss):
        """Return the traces of segment ``ss`` at downsampling factor ``ds``.

        The source depends on ``ds``:

        - ``ds > 1``: data comes from the lazily built multiscale recordings,
          using twice the base segment size (as in the original
          implementation).
        - Otherwise: frames ``[ss * segment_size, (ss + 1) * segment_size)``
          are read from the recording, with the end clamped to the
          recording's frame count.

        Returns:
            The segment data, or None if no recording is loaded.
        """
        if self._recording is None:
            return None
        logger.info('_load_data %s %s', ds, ss)
        if ds > 1:
            if self._multiscale_recordings is None:
                self.set_state(
                    dict(status_message='Creating multiscale recordings...'))
                self._multiscale_recordings = _create_multiscale_recordings(
                    recording=self._recording,
                    progressive_ds_factor=3,
                    create_efficient_access=self._create_efficient_access)
                self.set_state(
                    dict(status_message='Done creating multiscale recording'))
            rx = self._multiscale_recordings[ds]
            start_time = time.time()
            data = _extract_data_segment(
                recording=rx, segment_num=ss, segment_size=self._segment_size * 2)
        else:
            start_time = time.time()
            num_frames = self._recording.get_num_frames()
            data = self._recording.get_traces(
                start_frame=ss * self._segment_size,
                end_frame=min((ss + 1) * self._segment_size, num_frames))
        logger.info('extracted data segment %s %s %s',
                    ds, ss, time.time() - start_time)
        return data

    def iterate(self):
        """No periodic work is needed for this widget."""
        pass

    def _set_state(self, **kwargs):
        """Forward keyword arguments to ``set_state`` as a single dict."""
        self.set_state(kwargs)

    def _set_error(self, error_message):
        """Report an error status with the given message."""
        self._set_status('error', error_message)

    def _set_status(self, status, status_message=''):
        """Publish the ``status`` and ``status_message`` state keys."""
        self._set_state(status=status, status_message=status_message)
def main():
    """Denoise a filtered recording by iterative nearest-neighbour PLS and show it.

    Steps:
      1. Filter ``recobj`` with ``filter_recording`` (freq_min=300,
         freq_max=6000, freq_wid=1000).
      2. Detect reference events on the per-frame maximum over channels of the
         sign-adjusted traces. Extract Gaussian-windowed snippets around them
         and build a PCA basis from those snippets.
      3. For ``num_passes`` passes:
         - Detect events in the current residual.
         - Project their windowed snippets onto the PCA basis.
         - Replace each projection with a PLS reconstruction from its nearest
           reference neighbours.
         - Move that reconstruction from the residual into the signal
           estimate.
      4. Stack [signal estimate; residual; filtered traces] along the channel
         axis, write the stack to a temporary .mda file and open it in a
         ``TimeseriesView``.

    NOTE(review): ``recobj`` is not defined in this function. It must be a
    module-level name bound before ``main()`` is called — confirm.
    """
    import os

    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw

    sw.init_electron()

    # Bandpass filter.
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval

    detect_threshold = 3  # multiple of the MAD noise level
    detect_interval = 200
    detect_interval_reference = 10
    detect_sign = -1
    snippet_len = (200, 200)  # frames before / after each event
    # Std of the Gaussian taper, as a fraction of half the snippet length.
    window_frac = 0.3
    num_passes = 20
    npca = 100
    max_t = 30000 * 100  # only events at frames <= max_t are used
    k = 20  # nearest neighbours per event
    ncomp = 4  # PLS components

    def detect_events(traces, interval):
        """Detect events on the channel-max of sign-adjusted traces."""
        if detect_sign < 0:
            oriented = -traces
        elif detect_sign == 0:
            oriented = np.abs(traces)
        else:
            oriented = traces
        sig = np.max(oriented, axis=0)
        # Median absolute deviation, scaled to a Gaussian standard deviation.
        noise_level = np.median(np.abs(sig)) / 0.6745
        times = detect_on_channel(sig, detect_threshold=noise_level * detect_threshold,
                                  detect_interval=interval, detect_sign=1, margin=1000)
        return times[times <= max_t]

    def flatten(snippets):
        """Reshape (events, channels, frames) snippets into (events, features)."""
        return snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])

    R = AutoRecordingExtractor(recobj2)
    traces = R.get_traces()
    # X becomes the residual and is modified in place below, so work on a
    # floating-point copy. It must not alias the extractor's data, and
    # subtracting the float reconstructions must not truncate to an int dtype.
    X = traces.astype(np.result_type(traces.dtype, np.float32))
    del traces

    times_reference = detect_events(X, detect_interval_reference)
    print(f'Num. reference events = {len(times_reference)}')
    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2 / (2 * window_frac**2))
    # Taper every snippet along its time axis (broadcast over events/channels).
    snippets_reference = snippets_reference * window0
    A_snippets_reference = flatten(snippets_reference)

    print('PCA...')
    # Only the right singular vectors are used. full_matrices=False avoids
    # allocating the (num_events x num_events) left factor. If there are fewer
    # than npca reference events, fewer components are returned.
    _, _, vh = np.linalg.svd(A_snippets_reference, full_matrices=False)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)
    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        times = detect_events(X, detect_interval)
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len) * window0
        A_snippets = flatten(snippets)
        features = A_snippets @ components_reference

        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)
        features2 = np.zeros(features.shape, dtype=features.dtype)

        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            # Drop the closest neighbour.
            # TODO: it may not always be necessary to exclude the first -- how
            # should we make that decision?
            inds0 = indices[j, 1:]
            f_neighbors = features_reference[inds0, :]
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :])
            features2[j, :] = pls.predict(f_neighbors.T).T

        A_snippets_denoised = features2 @ components_reference.T
        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)
        for t0, snippet_denoised_0 in zip(times, snippets_denoised):
            frames = slice(t0 - snippet_len[0], t0 + snippet_len[1])
            # Move the reconstruction from the residual into the signal estimate.
            X_signal[:, frames] += snippet_denoised_0
            X[:, frames] -= snippet_denoised_0

    # Re-read the filtered traces for display alongside signal and residual.
    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)
    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = os.path.join(tmpdir, 'raw.mda')
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        sig_recobj['raw'] = ka.store_file(raw_fname)
        sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()