def test_convert_recording_extractor_to_nwb(self, se_class, dataset_path, se_kwargs):
    print(f"\n\n\n TESTING {se_class.extractor_name}...")
    dataset_stem = Path(dataset_path).stem
    self.dataset.get(dataset_path)
    recording = se_class(**se_kwargs)

    # test writing to NWB
    if test_nwb:
        nwb_save_path = self.savedir / f"{se_class.__name__}_test_{dataset_stem}.nwb"
        se.NwbRecordingExtractor.write_recording(recording, nwb_save_path, write_scaled=True)
        nwb_recording = se.NwbRecordingExtractor(nwb_save_path)
        check_recordings_equal(recording, nwb_recording)

        if recording.has_unscaled:
            nwb_save_path_unscaled = self.savedir / f"{se_class.__name__}_test_{dataset_stem}_unscaled.nwb"
            if np.all(recording.get_channel_offsets() == 0):
                se.NwbRecordingExtractor.write_recording(recording, nwb_save_path_unscaled, write_scaled=False)
                nwb_recording = se.NwbRecordingExtractor(nwb_save_path_unscaled)
                check_recordings_equal(recording, nwb_recording, return_scaled=False)
                # skip the scaled check when NWB converts uint to int
                if recording.get_dtype(return_scaled=False) == nwb_recording.get_dtype(return_scaled=False):
                    check_recordings_equal(recording, nwb_recording, return_scaled=True)

    # test caching
    if test_caching:
        rec_cache = se.CacheRecordingExtractor(recording)
        check_recordings_equal(recording, rec_cache)

        if recording.has_unscaled:
            rec_cache_unscaled = se.CacheRecordingExtractor(recording, return_scaled=False)
            check_recordings_equal(recording, rec_cache_unscaled, return_scaled=False)
            check_recordings_equal(recording, rec_cache_unscaled, return_scaled=True)
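# Hedged usage sketch of the NWB round-trip exercised above; assumes pynwb is
# installed and uses the spikeextractors toy dataset. The helper name
# _example_nwb_round_trip is illustrative, not part of the test suite:
def _example_nwb_round_trip(save_path='toy.nwb'):
    rec, _ = se.example_datasets.toy_example()
    se.NwbRecordingExtractor.write_recording(rec, save_path, write_scaled=True)
    rec_nwb = se.NwbRecordingExtractor(save_path)
    check_recordings_equal(rec, rec_nwb)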
def test_cache_extractor(self):
    cache_rec = se.CacheRecordingExtractor(self.RX)
    check_recording_return_types(cache_rec)
    check_recordings_equal(self.RX, cache_rec)
    cache_rec.move_to('cache_rec')
    assert cache_rec.filename == 'cache_rec.dat'
    check_dumping(cache_rec)

    cache_rec = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2')
    check_recording_return_types(cache_rec)
    check_recordings_equal(self.RX, cache_rec)
    assert cache_rec.filename == 'cache_rec2.dat'
    check_dumping(cache_rec)

    # test saving to file
    del cache_rec
    assert Path('cache_rec2.dat').is_file()

    # test tmp
    cache_rec = se.CacheRecordingExtractor(self.RX)
    tmp_file = cache_rec.filename
    del cache_rec
    assert not Path(tmp_file).is_file()

    cache_sort = se.CacheSortingExtractor(self.SX)
    check_sorting_return_types(cache_sort)
    check_sortings_equal(self.SX, cache_sort)
    cache_sort.move_to('cache_sort')
    assert cache_sort.filename == 'cache_sort.npz'
    check_dumping(cache_sort)

    # test saving to file
    del cache_sort
    assert Path('cache_sort.npz').is_file()

    cache_sort = se.CacheSortingExtractor(self.SX, save_path='cache_sort2')
    check_sorting_return_types(cache_sort)
    check_sortings_equal(self.SX, cache_sort)
    assert cache_sort.filename == 'cache_sort2.npz'
    check_dumping(cache_sort)

    # test saving to file
    del cache_sort
    assert Path('cache_sort2.npz').is_file()

    # test tmp
    cache_sort = se.CacheSortingExtractor(self.SX)
    tmp_file = cache_sort.filename
    del cache_sort
    assert not Path(tmp_file).is_file()

    # cleanup
    os.remove('cache_rec.dat')
    os.remove('cache_rec2.dat')
    os.remove('cache_sort.npz')
    os.remove('cache_sort2.npz')
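# Hedged sketch of the cache lifecycle the test above relies on: a cache backed by
# an anonymous temp file disappears when the extractor is garbage-collected, while
# move_to() (or an explicit save_path) persists it. The helper name is illustrative:
def _example_cache_lifecycle():
    rec, _ = se.example_datasets.toy_example()
    cache = se.CacheRecordingExtractor(rec)  # traces dumped to a temp .dat file
    cache.move_to('persistent_rec')          # renamed file survives deletion
    assert cache.filename == 'persistent_rec.dat'
    del cache
    assert Path('persistent_rec.dat').is_file()
    os.remove('persistent_rec.dat')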
def save_si_object(object_name: str, si_object, output_folder, cache_raw=False,
                   include_properties=True, include_features=False):
    """
    Save an arbitrary SI object to a temporary location for NWB conversion.

    Parameters
    ----------
    object_name: str
        The unique name of the SpikeInterface object.
    si_object: RecordingExtractor or SortingExtractor
        The extractor to be saved.
    output_folder: str or Path
        The folder where the object is saved.
    cache_raw: bool
        If True, the Extractor is cached to a binary file
        (not recommended for RecordingExtractor objects) (default False).
    include_properties: bool
        If True, properties (channel or unit) are saved (default True).
    include_features: bool
        If True, spike features are saved (default False).
    """
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    if isinstance(si_object, se.RecordingExtractor):
        if not si_object.is_dumpable:
            cache = se.CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat")
        elif cache_raw:
            # save to json before caching to keep history (in case it's needed)
            json_file = output_folder / f"{object_name}.json"
            si_object.dump_to_json(json_file)
            cache = se.CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat")
        else:
            cache = si_object
    elif isinstance(si_object, se.SortingExtractor):
        if not si_object.is_dumpable:
            cache = se.CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz")
        elif cache_raw:
            # save to json before caching to keep history (in case it's needed)
            json_file = output_folder / f"{object_name}.json"
            si_object.dump_to_json(json_file)
            cache = se.CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz")
        else:
            cache = si_object
    else:
        raise ValueError("The 'si_object' argument should be a SpikeInterface Extractor!")

    pkl_file = output_folder / f"{object_name}.pkl"
    cache.dump_to_pickle(pkl_file, include_properties=include_properties,
                         include_features=include_features)
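# Hedged usage sketch for save_si_object; the staging folder name and toy
# extractors are illustrative placeholders:
def _example_save_si_object():
    rec, sort = se.example_datasets.toy_example()
    save_si_object('my_recording', rec, 'nwb_staging', cache_raw=False)
    save_si_object('my_sorting', sort, 'nwb_staging')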
def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='fft',
                    order=3, chunk_size=30000, cache_to_file=False, cache_chunks=False, dtype=None):
    '''
    Performs a lazy filter on the recording extractor traces.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to be filtered.
    freq_min: int or float
        High-pass cutoff frequency.
    freq_max: int or float
        Low-pass cutoff frequency.
    freq_wid: int or float
        Width of the filter (when type is 'fft').
    filter_type: str
        'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain.
        The 'butter' filter uses the scipy butter and filtfilt functions.
    order: int
        Order of the filter (if 'butter').
    chunk_size: int
        The chunk size to be used for the filtering.
    cache_to_file: bool (default False)
        If True, filtered traces are computed and cached all at once on disk in a temp file.
    cache_chunks: bool (default False)
        If True, each chunk is cached in memory (in a dict).
    dtype: dtype
        The dtype of the traces.

    Returns
    -------
    filter_recording: BandpassFilterRecording
        The filtered recording extractor object
    '''
    if cache_to_file:
        assert not cache_chunks, 'if cache_to_file is True, cache_chunks should be False'

    bpf_recording = BandpassFilterRecording(recording=recording, freq_min=freq_min,
                                            freq_max=freq_max, freq_wid=freq_wid,
                                            filter_type=filter_type, order=order,
                                            chunk_size=chunk_size, cache_chunks=cache_chunks,
                                            dtype=dtype)
    if cache_to_file:
        return se.CacheRecordingExtractor(bpf_recording, chunk_size=chunk_size)
    else:
        return bpf_recording
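# Hedged usage sketch: lazy on-demand filtering versus one-shot caching to a temp
# file (the toy recording is illustrative):
def _example_bandpass_usage():
    rec, _ = se.example_datasets.toy_example()
    rec_lazy = bandpass_filter(rec, freq_min=300, freq_max=6000)  # filtered chunk by chunk on access
    rec_cached = bandpass_filter(rec, cache_to_file=True)         # filtered once, then read from disk
    return rec_lazy, rec_cached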
def test_cache_extractor(self):
    cache_extractor = se.CacheRecordingExtractor(self.RX)
    self._check_recording_return_types(cache_extractor)
    self._check_recordings_equal(self.RX, cache_extractor)
    cache_extractor.save_to_file('cache')
    assert cache_extractor.get_filename() == 'cache.dat'
    del cache_extractor
    assert not Path('cache.dat').is_file()
def test_cache_extractor(self):
    cache_extractor = se.CacheRecordingExtractor(self.RX)
    check_recording_return_types(cache_extractor)
    check_recordings_equal(self.RX, cache_extractor)
    cache_extractor.save_to_file('cache')
    assert cache_extractor.filename == 'cache.dat'
    check_dumping(cache_extractor)

    # test saving to file
    del cache_extractor
    assert Path('cache.dat').is_file()

    # test tmp
    cache_extractor = se.CacheRecordingExtractor(self.RX)
    tmp_file = cache_extractor.filename
    del cache_extractor
    assert not Path(tmp_file).is_file()
def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_to_file=False,
                 cache_chunks=False):
    '''
    Performs a notch filter on the recording extractor traces using the scipy iirnotch function.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to be notch-filtered.
    freq: int or float
        The target frequency of the notch filter.
    q: int
        The quality factor of the notch filter.
    chunk_size: int
        The chunk size to be used for the filtering.
    cache_to_file: bool (default False)
        If True, filtered traces are computed and cached all at once on disk in a temp file.
    cache_chunks: bool (default False)
        If True, each chunk is cached in memory (in a dict).

    Returns
    -------
    filter_recording: NotchFilterRecording
        The notch-filtered recording extractor object
    '''
    if cache_to_file:
        assert not cache_chunks, 'if cache_to_file is True, cache_chunks should be False'

    notch_recording = NotchFilterRecording(recording=recording, freq=freq, q=q,
                                           chunk_size=chunk_size, cache_chunks=cache_chunks)
    if cache_to_file:
        return se.CacheRecordingExtractor(notch_recording, chunk_size=chunk_size)
    else:
        return notch_recording
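# Hedged usage sketch: chaining the lazy preprocessors, e.g. a notch filter at a
# suspected noise frequency followed by a band-pass (values are illustrative):
def _example_notch_then_bandpass():
    rec, _ = se.example_datasets.toy_example()
    rec_notch = notch_filter(rec, freq=3000, q=30)
    return bandpass_filter(rec_notch, freq_min=300, freq_max=6000)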
def make(self, key):
    key['analysis_file_name'] = AnalysisNwbfile().create(key['nwb_file_name'])
    # get the valid times
    # NOTE: we will sort independently between each entry in the valid times list
    sort_intervals = (SortIntervalList() &
                      {'nwb_file_name': key['nwb_file_name'],
                       'sort_interval_list_name': key['sort_interval_list_name']}).fetch1('sort_intervals')
    interval_list_name = (SpikeSortingParameters() & key).fetch1('interval_list_name')
    valid_times = (IntervalList() &
                   {'nwb_file_name': key['nwb_file_name'],
                    'interval_list_name': interval_list_name}).fetch('valid_times')[0]
    raw_data_obj = (Raw() & {'nwb_file_name': key['nwb_file_name']}).fetch_nwb()[0]['raw']
    timestamps = np.asarray(raw_data_obj.timestamps)
    sampling_rate = estimate_sampling_rate(timestamps[0:100000], 1.5)

    units = dict()
    units_valid_times = dict()
    units_sort_interval = dict()
    units_templates = dict()
    units_waveforms = dict()
    # we will add an offset to the unit_id for each sort interval to avoid duplicating ids
    unit_id_offset = 0

    # iterate through the arrays of sort intervals, sorting each interval separately
    for sort_interval in sort_intervals:
        # get the list of valid times for this sort interval
        recording_extractor, sort_interval_valid_times = self.get_recording_extractor(key, sort_interval)
        sort_parameters = (SpikeSorterParameters() &
                           {'sorter_name': key['sorter_name'],
                            'parameter_set_name': key['parameter_set_name']}).fetch1()
        # get a name for the recording extractor for this sort interval
        recording_extractor_path = os.path.join(os.environ['SPIKE_SORTING_STORAGE_DIR'],
                                                key['analysis_file_name'], '_', str(sort_interval))
        recording_extractor_cached = se.CacheRecordingExtractor(recording_extractor,
                                                                save_path=recording_extractor_path)
        print(f'Sorting {key}...')
        sort = si.sorters.run_mountainsort4(recording=recording_extractor_cached,
                                            **sort_parameters['parameter_dict'],
                                            grouping_property='group',
                                            output_folder=os.getenv('SORTING_TEMP_DIR', None))
        # create a stack of labelled arrays of the sorted spike times
        timestamps = np.asarray(raw_data_obj.timestamps)
        unit_ids = sort.get_unit_ids()

        # get the waveforms; we may want to specify these parameters more flexibly in the future
        waveform_params = st.postprocessing.get_waveforms_params()
        print(sort_parameters)
        waveform_params['grouping_property'] = 'group'
        # set the window to half of the clip size before and half after
        waveform_params['ms_before'] = sort_parameters['parameter_dict']['clip_size'] / sampling_rate * 1000 / 2
        waveform_params['ms_after'] = waveform_params['ms_before']
        waveform_params['max_spikes_per_unit'] = 1000
        waveform_params['dtype'] = 'i2'
        waveform_params['verbose'] = False
        templates = st.postprocessing.get_unit_templates(recording_extractor_cached, sort,
                                                         **waveform_params)

        # for the waveforms themselves we only need to change max_spikes_per_unit
        waveform_params['max_spikes_per_unit'] = 1e100
        waveforms = st.postprocessing.get_unit_waveforms(recording_extractor_cached, sort, unit_ids,
                                                         **waveform_params)

        for index, unit_id in enumerate(unit_ids):
            current_index = unit_id + unit_id_offset
            unit_spike_samples = sort.get_unit_spike_train(unit_id=unit_id)
            units[current_index] = timestamps[unit_spike_samples]
            # the templates are zero based, so we have to use the index here
            units_templates[current_index] = templates[index]
            units_waveforms[current_index] = waveforms[index]
            units_valid_times[current_index] = sort_interval_valid_times
            units_sort_interval[current_index] = [sort_interval]

        if len(unit_ids) > 0:
            unit_id_offset += np.max(unit_ids) + 1

    # add the units to the Analysis file
    # TODO: consider replacing with spikeinterface call if possible
    units_object_id, units_waveforms_object_id = AnalysisNwbfile().add_units(
        key['analysis_file_name'], units, units_templates, units_valid_times,
        units_sort_interval, units_waveforms=units_waveforms)
    key['units_object_id'] = units_object_id
    key['units_waveforms_object_id'] = units_waveforms_object_id
    self.insert1(key)
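# Hedged sketch of the unit-id offsetting used in make() above: ids from successive
# sort intervals are shifted so units from different intervals never collide
# (the helper and its data are illustrative):
def _example_offset_unit_ids(unit_ids_per_interval):
    merged, offset = {}, 0
    for interval_index, unit_ids in enumerate(unit_ids_per_interval):
        for uid in unit_ids:
            merged[uid + offset] = interval_index  # unique key across intervals
        if len(unit_ids) > 0:
            offset += max(unit_ids) + 1
    return merged

# _example_offset_unit_ids([[0, 1], [0, 2]]) -> keys {0, 1, 2, 4}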
def process_openephys(project, action_id, probe_path, sorter, acquisition_folder=None,
                      exdir_file_path=None, spikesort=True, compute_lfp=True, compute_mua=False,
                      parallel=False, spikesorter_params=None, server=None, bad_channels=None,
                      ref=None, split=None, sort_by=None, ms_before_wf=1, ms_after_wf=2,
                      bad_threshold=2, firing_rate_threshold=0, isi_viol_threshold=0):
    import spikeextractors as se
    import spiketoolkit as st
    import spikesorters as ss

    bad_channels = bad_channels or []
    proc_start = time.time()

    if server is None or server == 'local':
        if acquisition_folder is None:
            action = project.actions[action_id]
            exdir_path = _get_data_path(action)
            exdir_file = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
            acquisition = exdir_file["acquisition"]
            if acquisition.attrs['acquisition_system'] is None:
                raise ValueError('No Open Ephys acquisition system related to this action')
            openephys_session = acquisition.attrs["session"]
            openephys_path = Path(acquisition.directory) / openephys_session
        else:
            openephys_path = Path(acquisition_folder)
            assert exdir_file_path is not None
            exdir_path = Path(exdir_file_path)

        probe_path = probe_path or project.config.get('probe')
        recording = se.OpenEphysRecordingExtractor(str(openephys_path))
        recording = recording.load_probe_file(probe_path)

        if 'auto' not in bad_channels and len(bad_channels) > 0:
            recording_active = st.preprocessing.remove_bad_channels(recording, bad_channel_ids=bad_channels)
        else:
            recording_active = recording

        # apply filtering and cmr
        print('Writing filtered and common referenced data')
        tmp_folder = Path(f"tmp_{action_id}_si")
        # make sure the temporary processing folder exists before caching into it
        tmp_folder.mkdir(parents=True, exist_ok=True)
        freq_min_hp = 300
        freq_max_hp = 3000
        freq_min_lfp = 1
        freq_max_lfp = 300
        freq_resample_lfp = 1000
        freq_resample_mua = 1000
        type_hp = 'butter'
        order_hp = 5

        recording_hp = st.preprocessing.bandpass_filter(
            recording_active, freq_min=freq_min_hp, freq_max=freq_max_hp,
            filter_type=type_hp, order=order_hp)

        if ref is not None:
            if ref.lower() == 'cmr':
                reference = 'median'
            elif ref.lower() == 'car':
                reference = 'average'
            else:
                raise Exception("'ref' can be either 'cmr' or 'car'")
            if split == 'all':
                recording_cmr = st.preprocessing.common_reference(recording_hp, reference=reference)
            elif split == 'half':
                groups = [recording.get_channel_ids()[:int(len(recording.get_channel_ids()) / 2)],
                          recording.get_channel_ids()[int(len(recording.get_channel_ids()) / 2):]]
                recording_cmr = st.preprocessing.common_reference(recording_hp, groups=groups,
                                                                  reference=reference)
            else:
                if isinstance(split, list):
                    recording_cmr = st.preprocessing.common_reference(recording_hp, groups=split,
                                                                      reference=reference)
                else:
                    raise Exception("'split' must be a list of lists")
        else:
            recording_cmr = recording

        if 'auto' in bad_channels:
            recording_cmr = st.preprocessing.remove_bad_channels(recording_cmr, bad_channel_ids=None,
                                                                 bad_threshold=bad_threshold, seconds=10)
            recording_active = se.SubRecordingExtractor(recording,
                                                        channel_ids=recording_cmr.active_channels)

        print("Active channels: ", len(recording_active.get_channel_ids()))

        recording_lfp = st.preprocessing.bandpass_filter(
            recording_active, freq_min=freq_min_lfp, freq_max=freq_max_lfp)
        recording_lfp = st.preprocessing.resample(recording_lfp, freq_resample_lfp)
        recording_mua = st.preprocessing.resample(
            st.preprocessing.rectify(recording_active), freq_resample_mua)

        if spikesort:
            print('Bandpass filter')
            t_start = time.time()
            recording_cmr = se.CacheRecordingExtractor(recording_cmr, save_path=tmp_folder / 'filt.dat')
            print('Filter time: ', time.time() - t_start)
        if compute_lfp:
            print('Computing LFP')
            t_start = time.time()
            recording_lfp = se.CacheRecordingExtractor(recording_lfp, save_path=tmp_folder / 'lfp.dat')
            print('Filter time: ', time.time() - t_start)
        if compute_mua:
            print('Computing MUA')
            t_start = time.time()
            recording_mua = se.CacheRecordingExtractor(recording_mua, save_path=tmp_folder / 'mua.dat')
            print('Filter time: ', time.time() - t_start)

        print('Number of channels', recording_cmr.get_num_channels())

        if spikesort:
            try:
                # save attributes
                exdir_group = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
                ephys = exdir_group.require_group('processing').require_group('electrophysiology')
                spikesorting = ephys.require_group('spikesorting')
                sorting_group = spikesorting.require_group(sorter)
                output_folder = sorting_group.require_raw('output').directory
                if 'kilosort' in sorter:
                    sorting = ss.run_sorter(sorter, recording_cmr, parallel=parallel, verbose=True,
                                            delete_output_folder=True, **spikesorter_params)
                else:
                    sorting = ss.run_sorter(sorter, recording_cmr, parallel=parallel,
                                            grouping_property=sort_by, verbose=True,
                                            output_folder=output_folder, delete_output_folder=True,
                                            **spikesorter_params)
                spike_sorting_attrs = {'name': sorter, 'params': spikesorter_params}
                filter_attrs = {'hp_filter': {'low': freq_min_hp, 'high': freq_max_hp},
                                'lfp_filter': {'low': freq_min_lfp, 'high': freq_max_lfp,
                                               'resample': freq_resample_lfp},
                                'mua_filter': {'resample': freq_resample_mua}}
                reference_attrs = {'type': str(ref), 'split': str(split)}
                sorting_group.attrs.update({'spike_sorting': spike_sorting_attrs,
                                            'filter': filter_attrs,
                                            'reference': reference_attrs})
            except Exception:
                try:
                    shutil.rmtree(tmp_folder)
                except:
                    print(f'Could not remove tmp processing folder: {tmp_folder}')
                raise Exception("Spike sorting failed")
            print('Found ', len(sorting.get_unit_ids()), ' units!')

        # extract waveforms
        if spikesort:
            print('Saving Phy output')
            phy_folder = sorting_group.require_raw('phy').directory
            if firing_rate_threshold > 0:
                sorting_min = st.curation.threshold_firing_rates(
                    sorting, threshold=firing_rate_threshold, threshold_sign='less',
                    duration_in_frames=recording_cmr.get_num_frames())
                print("Removed ", (len(sorting.get_unit_ids()) - len(sorting_min.get_unit_ids())),
                      'units with less than', firing_rate_threshold, 'firing rate')
            else:
                sorting_min = sorting
            if isi_viol_threshold > 0:
                sorting_viol = st.curation.threshold_isi_violations(
                    sorting_min, threshold=isi_viol_threshold, threshold_sign='greater',
                    duration_in_frames=recording_cmr.get_num_frames())
                print("Removed ", (len(sorting_min.get_unit_ids()) - len(sorting_viol.get_unit_ids())),
                      'units with ISI violation greater than', isi_viol_threshold)
            else:
                sorting_viol = sorting_min
            t_start_save = time.time()
            sorting_viol.set_tmp_folder(tmp_folder)
            st.postprocessing.export_to_phy(recording_cmr, sorting_viol, output_folder=phy_folder,
                                            ms_before=ms_before_wf, ms_after=ms_after_wf, verbose=True,
                                            grouping_property=sort_by, recompute_info=True,
                                            save_as_property_or_feature=True)
            print('Save to phy time:', time.time() - t_start_save)

        if compute_lfp:
            print('Saving LFP to exdir format')
            se.ExdirRecordingExtractor.write_recording(recording_lfp, exdir_path, lfp=True)
        if compute_mua:
            print('Saving MUA to exdir format')
            se.ExdirRecordingExtractor.write_recording(recording_mua, exdir_path, mua=True)

        # save attributes
        exdir_group = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
        ephys = exdir_group.require_group('processing').require_group('electrophysiology')
        spike_sorting_attrs = {'name': sorter, 'params': spikesorter_params}
        filter_attrs = {'hp_filter': {'low': freq_min_hp, 'high': freq_max_hp},
                        'lfp_filter': {'low': freq_min_lfp, 'high': freq_max_lfp,
                                       'resample': freq_resample_lfp},
                        'mua_filter': {'resample': freq_resample_mua}}
        reference_attrs = {'type': str(ref), 'split': str(split)}
        ephys.attrs.update({'spike_sorting': spike_sorting_attrs,
                            'filter': filter_attrs,
                            'reference': reference_attrs})
        try:
            if spikesort:
                del recording_cmr
            if compute_lfp:
                del recording_lfp
            if compute_mua:
                del recording_mua
            shutil.rmtree(tmp_folder)
        except:
            print(f'Could not remove tmp processing folder: {tmp_folder}')
    else:
        config = expipe.config._load_config_by_name(None)
        assert server in [s['host'] for s in config.get('servers')]
        server_dict = [s for s in config.get('servers') if s['host'] == server][0]
        host = server_dict['domain']
        user = server_dict['user']
        password = server_dict['password']
        port = 22
        ssh, scp_client, sftp_client, pbar = utils.login(
            hostname=host, username=user, password=password, port=port)
        print('Invoking remote shell')
        remote_shell = utils.ShellHandler(ssh)

        ########################## SEND #######################################
        action = project.actions[action_id]
        exdir_path = _get_data_path(action)
        exdir_path_str = str(exdir_path)
        exdir_file = exdir.File(exdir_path, plugins=exdir.plugins.quantities)
        acquisition = exdir_file["acquisition"]
        if acquisition.attrs['acquisition_system'] is None:
            raise ValueError('No Open Ephys acquisition system related to this action')
        openephys_session = acquisition.attrs["session"]
        openephys_path = Path(acquisition.directory) / openephys_session
        print('Initializing transfer of "' + str(openephys_path) + '" to "' + host + '"')

        try:
            # make directory for untaring
            process_folder = '/tmp/process_' + str(np.random.randint(10000000))
            stdin, stdout, stderr = remote_shell.execute('mkdir ' + process_folder)
        except IOError:
            pass

        print('Packing tar archive')
        remote_acq = process_folder + '/acquisition'
        remote_tar = process_folder + '/acquisition.tar'
        # transfer acquisition folder
        local_tar = shutil.make_archive(str(openephys_path), 'tar', str(openephys_path))
        print(local_tar)
        scp_client.put(local_tar, remote_tar, recursive=False)
        # transfer probe file
        remote_probe = process_folder + '/probe.prb'
        scp_client.put(probe_path, remote_probe, recursive=False)

        remote_exdir = process_folder + '/main.exdir'
        remote_proc = process_folder + '/main.exdir/processing'
        remote_proc_tar = process_folder + '/processing.tar'
        local_proc = str(exdir_path / 'processing')
        local_proc_tar = local_proc + '.tar'

        # transfer spike params
        if spikesorter_params is not None:
            spike_params_file = 'spike_params.yaml'
            with open(spike_params_file, 'w') as f:
                yaml.dump(spikesorter_params, f)
            remote_yaml = process_folder + '/' + spike_params_file
            scp_client.put(spike_params_file, remote_yaml, recursive=False)
            try:
                os.remove(spike_params_file)
            except:
                print('Could not remove: ', spike_params_file)
        else:
            remote_yaml = 'none'

        extra_args = ""
        if not compute_lfp:
            extra_args = extra_args + ' --no-lfp'
        if not compute_mua:
            extra_args = extra_args + ' --no-mua'
        if not spikesort:
            extra_args = extra_args + ' --no-sorting'
        extra_args = extra_args + ' -bt {}'.format(bad_threshold)

        if ref is not None and isinstance(ref, str):
            ref = ref.lower()
        if split is not None and isinstance(split, str):
            split = split.lower()
        bad_channels_cmd = ''
        for bc in bad_channels:
            bad_channels_cmd = bad_channels_cmd + ' -bc ' + str(bc)
        ref_cmd = ''
        if ref is not None:
            ref_cmd = ' --ref ' + ref.lower()
        split_cmd = ''
        if split is not None:
            split_cmd = ' --split-channels ' + str(split)
        par_cmd = ''
        if not parallel:
            par_cmd = ' --no-par '
        sortby_cmd = ''
        if sort_by is not None:
            sortby_cmd = ' --sort-by ' + sort_by
        wf_cmd = ' --ms-before-wf ' + str(ms_before_wf) + ' --ms-after-wf ' + str(ms_after_wf)
        ms_cmd = ' --min-fr ' + str(firing_rate_threshold)
        isi_cmd = ' --min-isi ' + str(isi_viol_threshold)

        try:
            pbar[0].close()
        except Exception:
            pass

        print('Making acquisition folder')
        cmd = "mkdir " + remote_acq
        print('Shell: ', cmd)
        stdin, stdout, stderr = remote_shell.execute(cmd)

        print('Unpacking tar archive')
        cmd = "tar -xf " + remote_tar + " --directory " + remote_acq
        stdin, stdout, stderr = remote_shell.execute(cmd)

        print('Deleting tar archives')
        if not os.access(str(local_tar), os.W_OK):
            # is the error an access error?
            os.chmod(str(local_tar), stat.S_IWUSR)
        try:
            os.remove(str(local_tar))
        except:
            print('Could not remove: ', local_tar)

        ###################### PROCESS #######################################
        print('Processing on server')
        cmd = "expipe process openephys {} --probe-path {} --sorter {} --spike-params {} " \
              "--acquisition {} --exdir-path {} {} {} {} {} {} {} {} {} {}".format(
                  action_id, remote_probe, sorter, remote_yaml, remote_acq, remote_exdir,
                  bad_channels_cmd, ref_cmd, par_cmd, sortby_cmd, split_cmd, wf_cmd,
                  extra_args, ms_cmd, isi_cmd)
        stdin, stdout, stderr = remote_shell.execute(cmd, print_lines=True)
        print('Finished remote processing')

        ####################### RETURN PROCESSED DATA #######################
        print('Initializing transfer of "' + remote_proc + '" to "' + local_proc + '"')
        print('Packing tar archive')
        cmd = "tar -C " + remote_exdir + " -cf " + remote_proc_tar + ' processing'
        stdin, stdout, stderr = remote_shell.execute(cmd, print_lines=True)
        # wait for 5 seconds to ensure that packing is done
        time.sleep(5)
        scp_client.get(remote_proc_tar, local_proc_tar, recursive=False)
        try:
            pbar[0].close()
        except Exception:
            pass

        print('Unpacking tar archive')
        if 'processing' in exdir_file:
            if 'electrophysiology' in exdir_file['processing']:
                print('Merging with old processing/electrophysiology')
        with tarfile.open(str(local_proc_tar)) as tar:
            _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                 if 'tracking' not in m.name and 'exdir.yaml' in m.name]
            if spikesort:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                     if 'spikesorting' in m.name and sorter in m.name]
            if compute_lfp:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                     if 'LFP' in m.name]
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                     if 'group' in m.name and 'attributes' in m.name]
            if compute_mua:
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                     if 'MUA' in m.name]
                _ = [tar.extract(m, exdir_path_str) for m in tar.getmembers()
                     if 'group' in m.name and 'attributes' in m.name]

        print('Deleting tar archives')
        if not os.access(str(local_proc_tar), os.W_OK):
            # is the error an access error?
            os.chmod(str(local_proc_tar), stat.S_IWUSR)
        try:
            os.remove(str(local_proc_tar))
        except:
            print('Could not remove: ', local_proc_tar)
        # sftp_client.remove(remote_proc_tar)

        print('Deleting remote process folder')
        cmd = "rm -rf " + process_folder
        stdin, stdout, stderr = remote_shell.execute(cmd)

        #################### CLOSE UP #############################
        ssh.close()
        sftp_client.close()
        scp_client.close()

    # check for tracking and events (always locally)
    oe_recording = pyopenephys.File(str(openephys_path)).experiments[0].recordings[0]
    if len(oe_recording.tracking) > 0:
        print('Saving ', len(oe_recording.tracking), ' Open Ephys tracking sources')
        generate_tracking(exdir_path, oe_recording)
    if len(oe_recording.events) > 0:
        print('Saving ', len(oe_recording.events), ' Open Ephys event sources')
        generate_events(exdir_path, oe_recording)

    print('Saved to exdir: ', exdir_path)
    print("Total elapsed time: ", time.time() - proc_start)
    print('Finished processing')
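# Hedged usage sketch for process_openephys; the expipe project object, action id,
# and probe file below are illustrative placeholders, not values from the original:
def _example_process_openephys(project):
    process_openephys(project, action_id='ecephys_2021-01-01', probe_path='tetrode32.prb',
                      sorter='mountainsort4', spikesort=True, compute_lfp=True,
                      ref='cmr', split='all', bad_channels=['auto'])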
def test_waveforms():
    n_wf_samples = 100
    n_jobs = [0, 2]
    for n in n_jobs:
        for m in memmaps:
            print('N jobs', n, 'memmap', m)
            folder = 'test'
            if os.path.isdir(folder):
                shutil.rmtree(folder)
            rec, sort, waveforms, templates, max_chans, amps = create_signal_with_known_waveforms(
                n_waveforms=2, n_channels=4, n_wf_samples=n_wf_samples)
            rec, sort = create_dumpable_extractors_from_existing(folder, rec, sort)

            # get num samples in ms
            ms_cut = n_wf_samples // 2 / rec.get_sampling_frequency() * 1000

            # no group
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     save_property_or_features=False, n_jobs=n, memmap=m,
                                     recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt)
            assert 'waveforms' not in sort.get_shared_unit_spike_feature_names()

            # small chunks
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     save_property_or_features=False, n_jobs=n, memmap=m,
                                     chunk_mb=5, recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt)
            assert 'waveforms' not in sort.get_shared_unit_spike_feature_names()

            # return_scaled
            gain = 0.1
            rec_sc, sort_sc = se.example_datasets.toy_example()
            rec_sc.set_channel_gains(gain)
            rec_sc.has_unscaled = True
            rec_cache = se.CacheRecordingExtractor(rec_sc, return_scaled=False)
            wav_unscaled = get_unit_waveforms(rec_cache, sort_sc, ms_before=ms_cut, ms_after=ms_cut,
                                              save_property_or_features=False, n_jobs=n, memmap=m,
                                              return_scaled=False, recompute_info=True)
            wav_unscaled = [np.array(wf) for wf in wav_unscaled]
            wav_scaled = get_unit_waveforms(rec_cache, sort_sc, ms_before=ms_cut, ms_after=ms_cut,
                                            save_property_or_features=False, n_jobs=n, memmap=m,
                                            return_scaled=True, recompute_info=True)
            wav_scaled = [np.array(wf) for wf in wav_scaled]
            for (w_unscaled, w_scaled) in zip(wav_unscaled, wav_scaled):
                assert np.allclose(w_unscaled * gain, w_scaled)

            # change cut ms
            wav = get_unit_waveforms(rec, sort, ms_before=2, ms_after=2,
                                     save_property_or_features=True, n_jobs=n, memmap=m,
                                     recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                _, _, samples = w.shape
                assert np.allclose(w[:, :, samples // 2 - n_wf_samples // 2:
                                      samples // 2 + n_wf_samples // 2], w_gt)
            assert 'waveforms' in sort.get_shared_unit_spike_feature_names()

            # by group
            rec.set_channel_groups([0, 0, 1, 1])
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     grouping_property='group', n_jobs=n, memmap=m,
                                     recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :2]) or np.allclose(w, w_gt[:, 2:])

            # test compute_property_from_recording
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     grouping_property='group', compute_property_from_recording=True,
                                     n_jobs=n, memmap=m, recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :2]) or np.allclose(w, w_gt[:, 2:])

            # test max_spikes_per_unit
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     max_spikes_per_unit=10, save_property_or_features=False,
                                     n_jobs=n, memmap=m, recompute_info=True)
            for w in wav:
                assert len(w) <= 10

            # test channels
            wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut,
                                     channel_ids=[0, 1, 2], n_jobs=n, memmap=m,
                                     recompute_info=True)
            for (w, w_gt) in zip(wav, waveforms):
                assert np.allclose(w, w_gt[:, :3])

    shutil.rmtree('test')