def javascript_state_changed(self, prev_state, state):
    """Handle a front-end state change for the recording widget.

    On the first call that carries a ``recording`` entry in *state*, build an
    ``AutoRecordingExtractor`` from it and publish basic recording info via
    ``set_state``. Once a recording is loaded, later calls only report status.

    Args:
        prev_state: Previous front-end state (not used).
        state: Current front-end state dict; ``state['recording']`` is the
            recording spec handed to ``AutoRecordingExtractor``.
    """
    self._set_status('running', 'Running Recording')
    if not self._recording:
        self._set_status('running', 'Loading recording')
        recording0 = state.get('recording', None)
        if not recording0:
            self._set_error('Missing: recording')
            return
        try:
            self._recording = AutoRecordingExtractor(recording0)
        except Exception as err:
            traceback.print_exc()
            self._set_error('Problem initiating recording: {}'.format(err))
            return

        self._set_status('running', 'Loading recording data')
        try:
            channel_locations = self._recording.get_channel_locations()
        except Exception:
            # If locations cannot be obtained, publish None instead of failing.
            # (``except Exception`` rather than a bare ``except:`` so that
            # KeyboardInterrupt/SystemExit still propagate.)
            channel_locations = None
        self.set_state(dict(
            num_channels=self._recording.get_num_channels(),
            channel_ids=self._recording.get_channel_ids(),
            channel_locations=channel_locations,
            num_timepoints=self._recording.get_num_frames(),
            samplerate=self._recording.get_sampling_frequency(),
            status_message='Loaded recording.'
        ))
    self._set_status('finished', '')
def javascript_state_changed(self, prev_state, state):
    """Handle a front-end state change for the timeseries view.

    On the first call that carries a ``recording`` entry in *state*, load it
    and publish channel info, per-channel display offsets, a display scale
    factor and the segment size via ``set_state``.

    Args:
        prev_state: Previous front-end state (not used).
        state: Current front-end state dict. Reads ``recording`` (recording
            spec for ``AutoRecordingExtractor``) and
            ``create_efficient_access`` (defaults to False).
    """
    self._set_status('running', 'Running TimeseriesView')
    self._create_efficient_access = state.get('create_efficient_access', False)

    if not self._recording:
        self._set_status('running', 'Loading recording')
        recording0 = state.get('recording', None)
        if not recording0:
            self._set_error('Missing: recording')
            return
        try:
            self._recording = AutoRecordingExtractor(recording0)
        except Exception as err:
            traceback.print_exc()
            self._set_error('Problem initiating recording: {}'.format(err))
            return

        self._set_status('running', 'Loading recording data')
        # Use at most the first 25000 frames of all channels to estimate the
        # display offsets and scale.
        traces0 = self._recording.get_traces(
            channel_ids=self._recording.get_channel_ids(),
            start_frame=0,
            end_frame=min(self._recording.get_num_frames(), 25000))
        y_offsets = -np.mean(traces0, axis=1)
        # Broadcast instead of adding in place: an in-place add into an
        # integer-typed traces array (e.g. int16) truncates the centered
        # values, and would also mutate the array returned by get_traces.
        centered = traces0 + y_offsets[:, np.newaxis]
        vv = np.percentile(np.abs(centered), 90)
        y_scale_factor = 1 / (2 * vv) if vv > 0 else 1

        # Frames per segment, chosen so frames x channels is roughly constant.
        self._segment_size = int(
            np.ceil(self._segment_size_times_num_channels / self._recording.get_num_channels()))
        try:
            channel_locations = self._recording.get_channel_locations()
        except Exception:
            # If locations cannot be obtained, publish None instead of failing.
            channel_locations = None
        self.set_state(
            dict(num_channels=self._recording.get_num_channels(),
                 channel_ids=self._recording.get_channel_ids(),
                 channel_locations=channel_locations,
                 num_timepoints=self._recording.get_num_frames(),
                 y_offsets=y_offsets,
                 y_scale_factor=y_scale_factor,
                 samplerate=self._recording.get_sampling_frequency(),
                 segment_size=self._segment_size,
                 status_message='Loaded recording.'))

    self._set_status('finished', '')
def herdingspikes2(recording_path, sorting_out, filter=True, pre_scale=True, pre_scale_value=20):
    """Sort a recording with HerdingSpikes2 and write the resulting sorting.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
        filter, pre_scale, pre_scale_value: Forwarded to the sorter unchanged.

    The ``NUM_WORKERS`` environment variable (1 when unset or empty) is
    forwarded as ``clustering_n_jobs``.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    workdir = '/tmp/tmp_herdingspikes2_' + _random_string(8)
    # important for when we are in a container
    os.environ['HS2_PROBE_PATH'] = workdir
    hs2 = ss.HerdingspikesSorter(recording=rx, output_folder=workdir, delete_output_folder=True)

    n_jobs = int(os.environ.get('NUM_WORKERS', None) or '1')
    hs2.set_params(
        filter=filter,
        pre_scale=pre_scale,
        pre_scale_value=pre_scale_value,
        clustering_n_jobs=n_jobs,
    )

    elapsed = hs2.run()
    # Runtime marker line, presumably parsed by the calling pipeline.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=hs2.get_result(), save_path=sorting_out)
def waveclus(recording_path, sorting_out):
    """Sort a recording with Wave_clus (default parameters) and save the sorting.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._waveclussorter import WaveclusSorter

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    wc = WaveclusSorter(
        recording=rx,
        output_folder='/tmp/tmp_waveclus_' + _random_string(8),
        delete_output_folder=True,
    )
    wc.set_params()  # no overrides: sorter defaults are used

    elapsed = wc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    result = wc.get_result()
    AutoSortingExtractor.write_sorting(sorting=result, save_path=sorting_out)
def kilosort2(recording, sorting_out):
    """Sort a recording with Kilosort2 using fixed parameters and save the sorting.

    Args:
        recording: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosort2sorter import Kilosort2Sorter
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    rx = AutoRecordingExtractor(dict(path=recording), download=True)

    print('Sorting...')
    ks2 = Kilosort2Sorter(
        recording=rx,
        output_folder='/tmp/tmp_kilosort2_' + _random_string(8),
        delete_output_folder=True,
    )
    ks2.set_params(detect_sign=-1, detect_threshold=5, freq_min=150, pc_per_chan=3)

    elapsed = ks2.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=ks2.get_result(), save_path=sorting_out)
def tridesclous(recording_path, sorting_out):
    """Sort a recording with Tridesclous (default parameters) and save the sorting.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    workdir = '/tmp/tmp_tridesclous_' + _random_string(8)
    # important for when we are in a container
    # NOTE(review): the variable name suggests a HerdingSpikes2 setting; unclear
    # whether Tridesclous reads it -- kept to preserve behavior, confirm.
    os.environ['HS2_PROBE_PATH'] = workdir
    tdc = ss.TridesclousSorter(
        recording=rx,
        output_folder=workdir,
        delete_output_folder=True,
        verbose=True,
    )
    tdc.set_params()  # no overrides: sorter defaults are used

    elapsed = tdc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=tdc.get_result(), save_path=sorting_out)
def mountainsort4(recording: str, sorting_out: str) -> None:
    """Filter, whiten and sort a recording with MountainSort4, then save the sorting.

    The recording is bandpass filtered (300-6000 Hz) and whitened with
    spiketoolkit, then sorted with fixed parameters.

    Args:
        recording: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.

    Returns:
        None; the sorting is written to *sorting_out*.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    recording = AutoRecordingExtractor(dict(path=recording), download=True)

    # Preprocessing
    print('Preprocessing...')
    recording = st.preprocessing.bandpass_filter(recording, freq_min=300, freq_max=6000)
    recording = st.preprocessing.whiten(recording)

    # Sorting
    print('Sorting...')
    sorter = ss.Mountainsort4Sorter(recording=recording, output_folder='/tmp/tmp_mountainsort4_' + _random_string(8),
                                    delete_output_folder=True)
    sorter.set_params(detect_sign=-1, adjacency_radius=50, detect_threshold=4)
    timer = sorter.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)
def kilosort(
    recording_path,
    sorting_out,
    detect_threshold=6,
    freq_min=300,
    freq_max=6000,
    Nt=128 * 1024 * 5 + 64  # batch size for kilosort
):
    """Sort a recording with Kilosort and save the sorting.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
        detect_threshold, freq_min, freq_max: Forwarded to the sorter.
        Nt: Batch size for kilosort.

    NOTE(review): ``Nt`` is accepted but never forwarded to ``set_params``;
    confirm whether KilosortSorter supports it and whether it should be passed.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosortsorter import KilosortSorter

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    ks = KilosortSorter(
        recording=rx,
        output_folder='/tmp/tmp_kilosort_' + _random_string(8),
        delete_output_folder=True,
    )
    ks.set_params(
        detect_threshold=detect_threshold,
        freq_min=freq_min,
        freq_max=freq_max,
        car=True,  # always enabled here
    )

    elapsed = ks.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=ks.get_result(), save_path=sorting_out)
def mountainsort4(
    recording_path: str,
    sorting_out: str,
    detect_sign=-1,
    adjacency_radius=50,
    clip_size=50,
    detect_threshold=3,
    detect_interval=10,
    freq_min=300,
    freq_max=6000
):
    """Sort a recording with MountainSort4 (sorter-side filter + whitening) and save it.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
        detect_sign, adjacency_radius, clip_size, detect_threshold,
        detect_interval, freq_min, freq_max: Forwarded to the sorter.

    ``NUM_WORKERS`` (0 when unset or empty) is forwarded as ``num_workers``;
    curation is disabled, whitening and filtering are enabled.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    ms4 = ss.Mountainsort4Sorter(
        recording=rx,
        output_folder='/tmp/tmp_mountainsort4_' + _random_string(8),
        delete_output_folder=True,
    )

    raw_workers = os.environ.get('NUM_WORKERS', None)
    n_workers = int(raw_workers) if raw_workers else 0

    ms4.set_params(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        clip_size=clip_size,
        detect_threshold=detect_threshold,
        detect_interval=detect_interval,
        num_workers=n_workers,
        curation=False,
        whiten=True,
        filter=True,
        freq_min=freq_min,
        freq_max=freq_max,
    )

    elapsed = ms4.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=ms4.get_result(), save_path=sorting_out)
def filter_recording(recording_directory, timeseries_out):
    """Bandpass filter a recording (300-6000 Hz, freq_wid 1000) and write its traces.

    Args:
        recording_directory: Recording spec passed to ``AutoRecordingExtractor``.
        timeseries_out: Output path for ``writemda32``.

    Raises:
        Exception: if ``writemda32`` reports failure.
    """
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st

    source = AutoRecordingExtractor(recording_directory)
    filtered = st.preprocessing.bandpass_filter(
        recording=source, freq_min=300, freq_max=6000, freq_wid=1000
    )
    written = writemda32(filtered.get_traces(), timeseries_out)
    if not written:
        raise Exception('Unable to write output file.')
def jrclust(
    recording_path,
    sorting_out,
    detect_sign=-1,  # Use -1, 0, or 1, depending on the sign of the spikes in the recording
    adjacency_radius=50,
    detect_threshold=4.5,  # detection threshold
    freq_min=300,
    freq_max=3000,
    merge_thresh=0.98,
    pc_per_chan=1,
    filter_type='bandpass',  # {none, bandpass, wiener, fftdiff, ndiff}
    nDiffOrder='none',
    min_count=30,
    fGpu=0,
    fParfor=0,
    feature_type='gpca'  # {gpca, pca, vpp, vmin, vminmax, cov, energy, xcov}
):
    """Sort a recording with JRClust and save the sorting.

    Every keyword argument is forwarded unchanged to ``JRClustSorter.set_params``.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._jrclustsorter import JRClustSorter

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    jrc = JRClustSorter(
        recording=rx,
        output_folder='/tmp/tmp_jrclust_' + _random_string(8),
        delete_output_folder=True,
    )
    params = dict(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        detect_threshold=detect_threshold,
        freq_min=freq_min,
        freq_max=freq_max,
        merge_thresh=merge_thresh,
        pc_per_chan=pc_per_chan,
        filter_type=filter_type,
        nDiffOrder=nDiffOrder,
        min_count=min_count,
        fGpu=fGpu,
        fParfor=fParfor,
        feature_type=feature_type,
    )
    jrc.set_params(**params)

    elapsed = jrc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=jrc.get_result(), save_path=sorting_out)
def filter_recording(recobj, freq_min=300, freq_max=6000, freq_wid=1000):
    """Bandpass filter a recording object and return a copy pointing at the filtered data.

    Args:
        recobj: Recording object (dict-like; ``.copy()`` is called on it)
            accepted by ``AutoRecordingExtractor``.
        freq_min, freq_max, freq_wid: Forwarded to spiketoolkit's bandpass filter.

    Returns:
        A copy of *recobj* whose ``'raw'`` entry is replaced by the result of
        ``ka.store_file`` on the filtered traces (written as .mda).

    Raises:
        Exception: if ``writemda32`` reports failure.
    """
    from spikeforest2_utils import AutoRecordingExtractor
    from spikeforest2_utils import writemda32
    import spiketoolkit as st

    source = AutoRecordingExtractor(recobj)
    filtered = st.preprocessing.bandpass_filter(
        recording=source, freq_min=freq_min, freq_max=freq_max, freq_wid=freq_wid
    )
    filtered_obj = recobj.copy()
    with hither.TemporaryDirectory() as tmpdir:
        mda_path = tmpdir + '/raw.mda'
        if not writemda32(filtered.get_traces(), mda_path):
            raise Exception('Unable to write output file.')
        # Store while the temporary file still exists.
        filtered_obj['raw'] = ka.store_file(mda_path)
    return filtered_obj
def compute_recording_info(recording_path, json_out):
    """Write a JSON summary (samplerate, channel count, duration in s) of a recording.

    Args:
        recording_path: Recording spec passed to ``AutoRecordingExtractor``.
        json_out: Path of the JSON file to write.
    """
    rx = AutoRecordingExtractor(recording_path)
    info = {
        'samplerate': rx.get_sampling_frequency(),
        'num_channels': len(rx.get_channel_ids()),
        'duration_sec': rx.get_num_frames() / rx.get_sampling_frequency(),
    }
    with open(json_out, 'w') as fh:
        json.dump(info, fh)
def ironclust(recording, sorting_out):
    """Sort a recording with IronClust using a fixed parameter set and save the sorting.

    Args:
        recording: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._ironclustsorter import IronClustSorter
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    rx = AutoRecordingExtractor(dict(path=recording), download=True)

    print('Sorting...')
    irc = IronClustSorter(
        recording=rx,
        output_folder='/tmp/tmp_ironclust_' + _random_string(8),
        delete_output_folder=True,
    )
    fixed_params = dict(
        detect_sign=-1,
        adjacency_radius=50,
        adjacency_radius_out=75,
        detect_threshold=4,
        prm_template_name='',
        freq_min=300,
        freq_max=8000,
        merge_thresh=0.99,
        pc_per_chan=0,
        whiten=False,
        filter_type='bandpass',
        filter_detect_type='none',
        common_ref_type='mean',
        batch_sec_drift=300,
        step_sec_drift=20,
        knn=30,
        min_count=30,
        fGpu=True,
        fft_thresh=8,
        fft_thresh_low=0,
        nSites_whiten=32,
        feature_type='gpca',
        delta_cut=1,
        post_merge_mode=1,
        sort_mode=1,
    )
    irc.set_params(**fixed_params)

    elapsed = irc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=irc.get_result(), save_path=sorting_out)
def spykingcircus(recording_path, sorting_out, detect_sign=-1, adjacency_radius=200, detect_threshold=6,
                  template_width_ms=3, filter=True, merge_spikes=True, auto_merge=0.75, whitening_max_elts=1000,
                  clustering_max_elts=10000):
    """Sort a recording with SpyKING CIRCUS and save the sorting.

    Keyword arguments are forwarded unchanged to the sorter. ``NUM_WORKERS``
    (1 when unset or empty) is forwarded as ``num_workers``.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    sc = ss.SpykingcircusSorter(
        recording=rx,
        output_folder='/tmp/tmp_spykingcircus_' + _random_string(8),
        delete_output_folder=True,
    )

    n_workers = int(os.environ.get('NUM_WORKERS', None) or '1')

    sc.set_params(
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        detect_threshold=detect_threshold,
        template_width_ms=template_width_ms,
        filter=filter,
        merge_spikes=merge_spikes,
        auto_merge=auto_merge,
        num_workers=n_workers,
        whitening_max_elts=whitening_max_elts,
        clustering_max_elts=clustering_max_elts,
    )

    elapsed = sc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=sc.get_result(), save_path=sorting_out)
def klusta(
    recording_path,
    sorting_out,
    adjacency_radius=None,
    detect_sign=-1,
    threshold_strong_std_factor=5,
    threshold_weak_std_factor=2,
    n_features_per_channel=3,
    num_starting_clusters=3,
    extract_s_before=16,
    extract_s_after=32
):
    """Sort a recording with Klusta and save the sorting.

    Every keyword argument is forwarded unchanged to ``KlustaSorter.set_params``.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    kl = ss.KlustaSorter(
        recording=rx,
        output_folder='/tmp/tmp_klusta_' + _random_string(8),
        delete_output_folder=True,
    )
    params = dict(
        adjacency_radius=adjacency_radius,
        detect_sign=detect_sign,
        threshold_strong_std_factor=threshold_strong_std_factor,
        threshold_weak_std_factor=threshold_weak_std_factor,
        n_features_per_channel=n_features_per_channel,
        num_starting_clusters=num_starting_clusters,
        extract_s_before=extract_s_before,
        extract_s_after=extract_s_after,
    )
    kl.set_params(**params)

    elapsed = kl.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=kl.get_result(), save_path=sorting_out)
def prepare_dataset_from_hash(
    recording_paths: Union[str, List[str]],
    gt_paths: Union[str, List[str]],
    sorter_names: List[str],
    metric_names: List[str],
    cache_path: Path,
):
    """Build a stacked (X, y) dataset from paired recordings and ground-truth sortings.

    Args:
        recording_paths: One recording URI or a list of them; each must contain
            ``'//'`` (e.g. ``sha1dir://...``). Everything after the first ``'//'``
            names the per-recording cache subdirectory.
        gt_paths: Ground-truth sorting path(s), paired positionally with
            *recording_paths*.
        sorter_names: Forwarded to ``prepare_fp_dataset``.
        metric_names: Forwarded to ``prepare_fp_dataset``.
        cache_path: Root cache directory (``Path`` or ``str``).

    Returns:
        ``(np.vstack(all_X), np.hstack(all_y))`` over all recordings.

    Raises:
        ValueError: if the two path lists differ in length, are empty, or a
            recording path does not contain ``'//'``.
    """
    if isinstance(recording_paths, str):
        recording_paths = [recording_paths]
    if isinstance(gt_paths, str):
        gt_paths = [gt_paths]

    if len(recording_paths) != len(gt_paths):
        raise ValueError(
            f"You have provided {len(recording_paths)} recording hashes and {len(gt_paths)} ground truth hashes! These must be the same."
        )
    if not recording_paths:
        # np.vstack([]) would fail later with a less helpful message.
        raise ValueError("At least one recording hash and ground truth hash must be provided.")

    all_X = []
    all_y = []
    for recording_path, gt_path in zip(recording_paths, gt_paths):
        c_path = _recording_cache_dir(cache_path, recording_path)
        recording = AutoRecordingExtractor(recording_path, download=True)
        gt_sorting = AutoSortingExtractor(gt_path)
        session = SpikeSession(recording, gt_sorting, cache_path=c_path)
        X, y = prepare_fp_dataset(session, sorter_names=sorter_names, metric_names=metric_names)
        all_X.append(X)
        all_y.append(y)

    return np.vstack(all_X), np.hstack(all_y)


def _recording_cache_dir(cache_path, recording_path):
    """Return ``cache_path / <everything after the first '//' in recording_path>``."""
    # partition keeps the whole remainder, unlike split('//')[1], which would
    # drop anything after a second '//'.
    _, sep, remainder = recording_path.partition('//')
    if not sep or not remainder:
        raise ValueError(
            f"Expected a recording URI of the form 'scheme://...', got {recording_path!r}"
        )
    return Path(cache_path) / remainder
def load_spikeforest_data(recording_path: str, sorting_true_path: str, download=True):
    """Load a SpikeForest recording and its ground-truth sorting, printing a summary.

    Args:
        recording_path: Recording spec for ``AutoRecordingExtractor``.
        sorting_true_path: Ground-truth sorting spec for ``AutoSortingExtractor``.
        download: Forwarded to ``AutoRecordingExtractor``.

    Returns:
        ``(recording, sorting_GT)`` extractors.
    """
    rec = AutoRecordingExtractor(recording_path, download=download)
    gt_sorting = AutoSortingExtractor(sorting_true_path)

    # Query all recording metadata first, then report it.
    sampling_rate = rec.get_sampling_frequency()
    chan_ids = rec.get_channel_ids()
    chan_locs = rec.get_channel_locations()
    n_frames = rec.get_num_frames()
    duration_s = rec.frame_to_time(n_frames)
    for line in (
        f'Sampling frequency:{sampling_rate}',
        f'Channel ids:{chan_ids}',
        f'channel location:{chan_locs}',
        f'frame num:{n_frames}',
        f'recording duration:{duration_s}',
    ):
        print(line)

    gt_unit_ids = gt_sorting.get_unit_ids()
    print(f'unit ids:{gt_unit_ids}')

    return rec, gt_sorting
def kilosort2(
    recording_path,
    sorting_out,
    detect_threshold=6,
    car=True,  # whether to do common average referencing
    minFR=1 / 50,  # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed
    freq_min=150,  # min. bp filter freq (Hz), use 0 for no filter
    sigmaMask=30,  # sigmaMask
    nPCs=3,  # PCs per channel?
):
    """Sort a recording with Kilosort2 and save the sorting.

    Every keyword argument is forwarded unchanged to ``Kilosort2Sorter.set_params``.

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._kilosort2sorter import Kilosort2Sorter

    rx = AutoRecordingExtractor(dict(path=recording_path), download=True)

    print('Sorting...')
    ks2 = Kilosort2Sorter(
        recording=rx,
        output_folder='/tmp/tmp_kilosort2_' + _random_string(8),
        delete_output_folder=True,
    )
    ks2.set_params(
        detect_threshold=detect_threshold,
        car=car,
        minFR=minFR,
        freq_min=freq_min,
        sigmaMask=sigmaMask,
        nPCs=nPCs,
    )

    elapsed = ks2.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=ks2.get_result(), save_path=sorting_out)
def spykingcircus(recording, sorting_out):
    """Sort a recording with SpyKING CIRCUS (default parameters) and save the sorting.

    Args:
        recording: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    import spiketoolkit as st
    import spikesorters as ss
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    import kachery as ka

    # TODO: need to think about how to deal with this
    ka.set_config(fr='default_readonly')

    rx = AutoRecordingExtractor(dict(path=recording), download=True)

    print('Sorting...')
    sc = ss.SpykingcircusSorter(
        recording=rx,
        output_folder='/tmp/tmp_spykingcircus_' + _random_string(8),
        delete_output_folder=True,
    )
    sc.set_params()  # no overrides: sorter defaults are used

    elapsed = sc.run()
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(elapsed))
    AutoSortingExtractor.write_sorting(sorting=sc.get_result(), save_path=sorting_out)
from pathlib import Path import numpy as np import kachery as ka from pykilosort import Bunch, add_default_handler, run from spikeextractors.extractors import bindatrecordingextractor as dat from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor dat_path = Path("test/test.bin").absolute() dir_path = dat_path.parent ka.set_config(fr="default_readonly") recording_path = "sha1dir://c0879a26f92e4c876cd608ca79192a84d4382868.manual_franklab/tetrode_600s/sorter1_1" recording = AutoRecordingExtractor(recording_path, download=True) recording.write_to_binary_dat_format(str(dat_path)) n_channels = len(recording.get_channel_ids()) probe = Bunch() probe.NchanTOT = n_channels probe.chanMap = np.array(range(0, n_channels)) probe.kcoords = np.ones(n_channels) probe.xc = recording.get_channel_locations()[:, 0] probe.yc = recording.get_channel_locations()[:, 1] add_default_handler(level="DEBUG") params = {"nfilt_factor": 8, "AUCsplit": 0.85, "nskip": 5} run( dat_path,
def main():
    """Iteratively denoise detected spike events in a filtered recording and display the result.

    Outline of the code below:
      1. Bandpass filter ``recobj`` via ``filter_recording.run(...)``
         (presumably a hither function).
      2. Detect reference events on the per-frame max over channels of the
         (sign-adjusted) traces. Extract snippets around them, taper them with
         a Gaussian window, and build a PCA basis plus a nearest-neighbor index
         on the reference features.
      3. For up to ``num_passes`` passes: detect events on the current residual
         ``X``, denoise each event's PCA features by PLS regression on its
         nearest reference neighbors, add the denoised snippet to ``X_signal``,
         and subtract it from ``X``.
      4. Stack [X_signal; residual X; filtered traces] along the channel axis
         into one recording and show it in a TimeseriesView.

    NOTE(review): ``recobj``, ``filter_recording``, ``hither``, ``ka``, ``np``,
    ``detect_on_channel`` and ``extract_snippets`` are not defined in this
    function; presumably module-level names -- confirm.
    """
    import spikeextractors as se  # NOTE(review): imported but not used below
    from spikeforest2_utils import writemda32, AutoRecordingExtractor
    from sklearn.neighbors import NearestNeighbors
    from sklearn.cross_decomposition import PLSRegression
    import spikeforest_widgets as sw
    sw.init_electron()

    # bandpass filter
    with hither.config(container='default', cache='default_readwrite'):
        recobj2 = filter_recording.run(
            recobj=recobj,
            freq_min=300,
            freq_max=6000,
            freq_wid=1000
        ).retval

    # --- parameters ---
    detect_threshold = 3  # multiplied by noise_level to form the detection threshold
    detect_interval = 200  # passed to detect_on_channel for per-pass detection
    detect_interval_reference = 10  # passed to detect_on_channel for reference detection
    detect_sign = -1  # <0: flip sign before detection; 0: use |traces|
    num_events = 1000  # NOTE(review): not used below
    snippet_len = (200, 200)  # frames before/after each event (see the X slicing below)
    window_frac = 0.3  # Gaussian taper width on a [-1, 1] snippet axis
    num_passes = 20  # maximum number of detect/denoise/subtract passes
    npca = 100  # number of PCA components kept
    max_t = 30000 * 100  # events at frames > max_t are ignored
    k = 20  # neighbors per event (index queried with k + 1; the first is dropped)
    ncomp = 4  # PLS components

    R = AutoRecordingExtractor(recobj2)

    # X is indexed [channel, frame] (cf. X_signal's shape and slicing below);
    # it is updated in place to hold the residual after each pass.
    X = R.get_traces()

    # Detection signal: per-frame maximum over channels after sign adjustment.
    sig = X.copy()
    if detect_sign < 0:
        sig = -sig
    elif detect_sign == 0:
        sig = np.abs(sig)
    sig = np.max(sig, axis=0)
    noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
    times_reference = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval_reference, detect_sign=1, margin=1000)
    times_reference = times_reference[times_reference <= max_t]
    print(f'Num. reference events = {len(times_reference)}')

    # Reference snippets, indexed [event, channel, sample] (per the loops below),
    # tapered by a Gaussian window along the sample axis.
    snippets_reference = extract_snippets(X, reference_frames=times_reference, snippet_len=snippet_len)
    tt = np.linspace(-1, 1, snippets_reference.shape[2])
    window0 = np.exp(-tt**2/(2*window_frac**2))
    for j in range(snippets_reference.shape[0]):
        for m in range(snippets_reference.shape[1]):
            snippets_reference[j, m, :] = snippets_reference[j, m, :] * window0
    # Flatten each snippet to one row: (num_events, channels * samples).
    A_snippets_reference = snippets_reference.reshape(snippets_reference.shape[0], snippets_reference.shape[1] * snippets_reference.shape[2])

    print('PCA...')
    # NOTE(review): np.linalg.svd defaults to full_matrices=True, so u is
    # (num_events x num_events) although only vh is used; full_matrices=False
    # would avoid that memory cost.
    u, s, vh = np.linalg.svd(A_snippets_reference)
    components_reference = vh[0:npca, :].T
    features_reference = A_snippets_reference @ components_reference

    print('Setting up nearest neighbors...')
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(features_reference)

    # Accumulates the denoised (signal) part removed from X across passes.
    X_signal = np.zeros((R.get_num_channels(), R.get_num_frames()), dtype=np.float32)

    for passnum in range(num_passes):
        print(f'Pass {passnum}')
        # Re-detect on the current residual X, same procedure as for the reference.
        sig = X.copy()
        if detect_sign < 0:
            sig = -sig
        elif detect_sign == 0:
            sig = np.abs(sig)
        sig = np.max(sig, axis=0)
        noise_level = np.median(np.abs(sig)) / 0.6745  # median absolute deviation (MAD)
        times = detect_on_channel(sig, detect_threshold=noise_level*detect_threshold, detect_interval=detect_interval, detect_sign=1, margin=1000)
        times = times[times <= max_t]
        print(f'Number of events: {len(times)}')
        if len(times) == 0:
            break
        snippets = extract_snippets(X, reference_frames=times, snippet_len=snippet_len)
        for j in range(snippets.shape[0]):
            for m in range(snippets.shape[1]):
                snippets[j, m, :] = snippets[j, m, :] * window0
        A_snippets = snippets.reshape(snippets.shape[0], snippets.shape[1] * snippets.shape[2])
        # Project onto the reference PCA basis.
        features = A_snippets @ components_reference

        print('Finding nearest neighbors...')
        distances, indices = nbrs.kneighbors(features)  # distances is not used
        features2 = np.zeros(features.shape, dtype=features.dtype)
        print('PLS regression...')
        for j in range(features.shape[0]):
            print(f'{j+1} of {features.shape[0]}')
            inds0 = np.squeeze(indices[j, :])
            inds0 = inds0[1:] # TODO: it may not always be necessary to exclude the first -- how should we make that decision?
            f_neighbors = features_reference[inds0, :]
            # Treat each PCA coordinate as a sample: regress this event's feature
            # vector on its neighbors' feature vectors; the fitted values are
            # taken as the denoised features.
            pls = PLSRegression(n_components=ncomp)
            pls.fit(f_neighbors.T, features[j, :].T)
            features2[j, :] = pls.predict(f_neighbors.T).T
        # Back-project denoised features to snippet space.
        A_snippets_denoised = features2 @ components_reference.T

        snippets_denoised = A_snippets_denoised.reshape(snippets.shape)

        # Move each denoised snippet from the residual X into X_signal.
        # Assumes extract_snippets uses the same [t0 - before, t0 + after) window
        # as these slices -- confirm.
        for j in range(len(times)):
            t0 = times[j]
            snippet_denoised_0 = np.squeeze(snippets_denoised[j, :, :])
            X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] = X_signal[:, t0-snippet_len[0]:t0+snippet_len[1]] + snippet_denoised_0
            X[:, t0-snippet_len[0]:t0+snippet_len[1]] = X[:, t0-snippet_len[0]:t0+snippet_len[1]] - snippet_denoised_0

    # Channels: [denoised signal; residual; filtered original traces].
    S = np.concatenate((X_signal, X, R.get_traces()), axis=0)

    with hither.TemporaryDirectory() as tmpdir:
        raw_fname = tmpdir + '/raw.mda'
        writemda32(S, raw_fname)
        sig_recobj = recobj2.copy()
        # Store while the temporary file still exists.
        sig_recobj['raw'] = ka.store_file(raw_fname)

    sw.TimeseriesView(recording=AutoRecordingExtractor(sig_recobj)).show()
class TimeseriesView:
    """Python backend for the timeseries viewer widget.

    ``javascript_state_changed`` loads the recording named in the front-end
    state and publishes summary/display info; ``on_message`` answers
    ``'requestSegment'`` messages with base64-encoded trace segments.

    NOTE(review): ``set_state`` and ``send_message`` are not defined in this
    class; presumably supplied by the widget framework -- confirm.
    """

    def __init__(self):
        super().__init__()
        self._recording = None
        self._multiscale_recordings = None
        # Target frames x channels per served segment (see _segment_size below).
        self._segment_size_times_num_channels = 1000000
        self._segment_size = None
        # Updated from the front-end state in javascript_state_changed;
        # initialized here so _load_data never reads an undefined attribute.
        self._create_efficient_access = False

    def javascript_state_changed(self, prev_state, state):
        """Handle a front-end state change.

        On the first call that carries a ``recording`` entry in *state*, load
        it and publish channel info, per-channel display offsets, a display
        scale factor and the segment size via ``set_state``.

        Args:
            prev_state: Previous front-end state (not used).
            state: Current front-end state dict. Reads ``recording`` and
                ``create_efficient_access`` (defaults to False).
        """
        self._set_status('running', 'Running TimeseriesView')
        self._create_efficient_access = state.get('create_efficient_access', False)

        if not self._recording:
            self._set_status('running', 'Loading recording')
            recording0 = state.get('recording', None)
            if not recording0:
                self._set_error('Missing: recording')
                return
            try:
                self._recording = AutoRecordingExtractor(recording0)
            except Exception as err:
                traceback.print_exc()
                self._set_error('Problem initiating recording: {}'.format(err))
                return

            self._set_status('running', 'Loading recording data')
            # Estimate display offsets/scale from at most the first 25000 frames.
            traces0 = self._recording.get_traces(
                channel_ids=self._recording.get_channel_ids(),
                start_frame=0,
                end_frame=min(self._recording.get_num_frames(), 25000))
            y_offsets, y_scale_factor = self._compute_y_offsets_and_scale(traces0)

            # Frames per segment, chosen so frames x channels is roughly constant.
            self._segment_size = int(
                np.ceil(self._segment_size_times_num_channels / self._recording.get_num_channels()))
            try:
                channel_locations = self._recording.get_channel_locations()
            except Exception:
                # If locations cannot be obtained, publish None instead of failing.
                channel_locations = None
            self.set_state(
                dict(num_channels=self._recording.get_num_channels(),
                     channel_ids=self._recording.get_channel_ids(),
                     channel_locations=channel_locations,
                     num_timepoints=self._recording.get_num_frames(),
                     y_offsets=y_offsets,
                     y_scale_factor=y_scale_factor,
                     samplerate=self._recording.get_sampling_frequency(),
                     segment_size=self._segment_size,
                     status_message='Loaded recording.'))

        self._set_status('finished', '')

    @staticmethod
    def _compute_y_offsets_and_scale(traces):
        """Return ``(y_offsets, y_scale_factor)`` for displaying *traces*.

        Args:
            traces: 2-D array indexed ``[channel, frame]``; not modified.

        Returns:
            ``y_offsets``: per-channel negated means (adding them zero-centers
            each channel). ``y_scale_factor``: ``1 / (2 * p90)``, where ``p90``
            is the 90th percentile of the absolute centered values, or 1 when
            that percentile is 0.
        """
        traces = np.asarray(traces)
        y_offsets = -np.mean(traces, axis=1)
        # Broadcast rather than add in place: in-place addition into an
        # integer array (e.g. int16) would truncate the centered values.
        centered = traces + y_offsets[:, np.newaxis]
        vv = np.percentile(np.abs(centered), 90)
        y_scale_factor = 1 / (2 * vv) if vv > 0 else 1
        return y_offsets, y_scale_factor

    def on_message(self, msg):
        """Handle a front-end message; only ``'requestSegment'`` is acted on."""
        if msg['command'] == 'requestSegment':
            ds = msg['ds_factor']
            ss = msg['segment_num']
            data0 = self._load_data(ds, ss)
            if data0 is None:
                # No recording loaded yet: there is nothing to encode or send.
                return
            data0_base64 = _mda32_to_base64(data0)
            self.send_message(
                dict(command='setSegment', ds_factor=ds, segment_num=ss, data=data0_base64))

    def _load_data(self, ds, ss):
        """Return traces for segment *ss* at downsampling factor *ds*.

        Returns None when no recording has been loaded. For ``ds > 1`` the
        multiscale recordings are created lazily on first use, and segments
        are ``2 * self._segment_size`` long.
        """
        if not self._recording:
            return
        logger.info('_load_data {} {}'.format(ds, ss))
        if ds > 1:
            if self._multiscale_recordings is None:
                self.set_state(
                    dict(status_message='Creating multiscale recordings...'))
                self._multiscale_recordings = _create_multiscale_recordings(
                    recording=self._recording,
                    progressive_ds_factor=3,
                    create_efficient_access=self._create_efficient_access)
                self.set_state(
                    dict(status_message='Done creating multiscale recording'))
            rx = self._multiscale_recordings[ds]
            start_time = time.time()
            X = _extract_data_segment(recording=rx, segment_num=ss, segment_size=self._segment_size * 2)
            logger.info('extracted data segment {} {} {}'.format(
                ds, ss, time.time() - start_time))
            return X

        start_time = time.time()
        traces = self._recording.get_traces(
            start_frame=ss * self._segment_size, end_frame=(ss + 1) * self._segment_size)
        logger.info('extracted data segment {} {} {}'.format(
            ds, ss, time.time() - start_time))
        return traces

    def iterate(self):
        pass

    def _set_state(self, **kwargs):
        self.set_state(kwargs)

    def _set_error(self, error_message):
        self._set_status('error', error_message)

    def _set_status(self, status, status_message=''):
        self._set_state(status=status, status_message=status_message)
import argparse
import spikeforest_widgets as sw
from spikeforest2_utils import AutoRecordingExtractor
import kachery as ka

# Demo: open a TimeseriesView of one fixed SpikeForest recording.

# kachery URI of the recording that is displayed.
RECORDING_URI = 'sha1dir://49b1fe491cbb4e0f90bde9cfc31b64f985870528.paired_boyden32c/531_2_1'

sw.init_electron()
ka.set_config(fr='default_readonly')

# No options are defined yet (a --path option for an analysis JSON file was
# sketched but is disabled); parse_args still provides --help and rejects
# unexpected arguments.
parser = argparse.ArgumentParser(description='Browse a SpikeForest analysis')
args = parser.parse_args()

R = AutoRecordingExtractor(RECORDING_URI)
X = sw.TimeseriesView(
    recording=R
)
X.show()
def compute_units_info(recording_path, sorting_path, json_out):
    """Compute per-unit info for a sorting of a recording and write it as JSON.

    Args:
        recording_path: Recording spec for ``AutoRecordingExtractor``.
        sorting_path: Sorting spec for ``AutoSortingExtractor``; it is given the
            recording's sampling frequency as ``samplerate``.
        json_out: Path of the JSON file to write.
    """
    rx = AutoRecordingExtractor(recording_path)
    sx = AutoSortingExtractor(sorting_path, samplerate=rx.get_sampling_frequency())
    units_info = _compute_units_info(recording=rx, sorting=sx)
    with open(json_out, 'w') as fh:
        json.dump(units_info, fh)
#Downloading the recording objects ka.set_config(fr='default_readonly') print( f'Downloading recording: {recordingZero["studyName"]}/{recordingZero["name"]}' ) ka.load_file(recordingZero['directory'] + '/raw.mda') ka.load_file(recordingZero['directory'] + '/params.json') ka.load_file(recordingZero['directory'] + '/geom.csv') ka.load_file(recordingZero['directory'] + '/firings_true.mda') #Attaching the results recordingZero['results'] = dict() #Tryting to plot the recordings recordingInput = AutoRecordingExtractor(dict(path=recordingPath), download=True) w_ts = sw.plot_timeseries(recordingInput) w_ts.figure.suptitle("Recording by group") w_ts.ax.set_ylabel("Channel_ids") #We will also try to plot the rastor plot for the ground truth gtOutput = AutoSortingExtractor(sortingPath) #We need to change the indices of the ground truth output w_rs_gt = sw.plot_rasters(gtOutput, sampling_frequency=sampleRate) #Spike-Sorting #trying to run SPYKINGCIRCUS through spike interface #spykingcircus with ka.config(fr='default_readonly'): #with hither.config(cache='default_readwrite'): with hither.config(container='default'):
#Downloading the recording objects ka.set_config(fr='default_readonly') print( f'Downloading recording: {recordingZero["studyName"]}/{recordingZero["name"]}' ) ka.load_file(recordingZero['directory'] + '/raw.mda') ka.load_file(recordingZero['directory'] + '/params.json') ka.load_file(recordingZero['directory'] + '/geom.csv') ka.load_file(recordingZero['directory'] + '/firings_true.mda') #Attaching the results recordingZero['results'] = dict() #Tryting to plot the recordings recordingInput = AutoRecordingExtractor(dict(path=recordingPath), download=True) w_ts = sw.plot_timeseries(recordingInput) w_ts.figure.suptitle("Recording by group") w_ts.ax.set_ylabel("Channel_ids") gtOutput = AutoSortingExtractor(sortingPath) #Only getting a part of the recording #gtOutput.add_epoch(epoch_name="first_half", start_frame=0, end_frame=recordingInput.get_num_frames()/2) #set #subsorting = gtOutput.get_epoch("first_half") #w_rs_gt = sw.plot_rasters(gtOutput,sampling_frequency=sampleRate) #w_wf_gt = sw.plot_unit_waveforms(recordingInput,gtOutput, max_spikes_per_unit=100) #We will also try to plot the rastor plot for the ground truth
def ironclust(recording_path, sorting_out, detect_threshold=4, freq_min=300, freq_max=0, detect_sign=-1,
              adjacency_radius=50, whiten=False, adjacency_radius_out=100, merge_thresh=0.95, fft_thresh=8, knn=30,
              min_count=30, delta_cut=1, pc_per_chan=6, batch_sec_drift=600, step_sec_drift=20,
              common_ref_type='trimmean', fGpu=True, clip_pre=0.25, clip_post=0.75, merge_thresh_cc=1):
    """Sort a recording with IronClust and save the sorting.

    The keyword arguments are forwarded to ``IronClustSorter.set_params``
    together with a fixed set of extra parameters (fft_thresh_low,
    nSites_whiten, feature_type, post_merge_mode, sort_mode,
    prm_template_name, filter_type, filter_detect_type).

    Args:
        recording_path: Recording path/URI, wrapped as ``dict(path=...)`` for
            ``AutoRecordingExtractor`` (called with ``download=True``).
        sorting_out: Destination passed to ``AutoSortingExtractor.write_sorting``.
    """
    from spikeforest2_utils import AutoRecordingExtractor, AutoSortingExtractor
    from ._ironclustsorter import IronClustSorter

    recording = AutoRecordingExtractor(dict(path=recording_path), download=True)

    # Sorting
    print('Sorting...')
    sorter = IronClustSorter(recording=recording, output_folder='/tmp/tmp_ironclust_' + _random_string(8),
                             delete_output_folder=True)

    sorter.set_params(fft_thresh_low=0, nSites_whiten=32, feature_type='gpca', post_merge_mode=1, sort_mode=1,
                      prm_template_name='', filter_type='bandpass', filter_detect_type='none',
                      detect_threshold=detect_threshold, freq_min=freq_min, freq_max=freq_max,
                      detect_sign=detect_sign, adjacency_radius=adjacency_radius, whiten=whiten,
                      adjacency_radius_out=adjacency_radius_out, merge_thresh=merge_thresh,
                      fft_thresh=fft_thresh, knn=knn, min_count=min_count, delta_cut=delta_cut,
                      pc_per_chan=pc_per_chan, batch_sec_drift=batch_sec_drift,
                      step_sec_drift=step_sec_drift, common_ref_type=common_ref_type, fGpu=fGpu,
                      clip_pre=clip_pre, clip_post=clip_post, merge_thresh_cc=merge_thresh_cc)
    timer = sorter.run()
    # Emit the runtime marker like every other sorter wrapper does.
    print('#SF-SORTER-RUNTIME#{:.3f}#'.format(timer))
    sorting = sorter.get_result()

    AutoSortingExtractor.write_sorting(sorting=sorting, save_path=sorting_out)