def __init__(self, sampling_frequency=None, use_natural_unit_ids=False, **neo_kwargs):
    _NeoBaseExtractor.__init__(self, **neo_kwargs)

    self.use_natural_unit_ids = use_natural_unit_ids

    if sampling_frequency is None:
        sampling_frequency = self._auto_guess_sampling_frequency()

    spike_channels = self.neo_reader.header['spike_channels']

    if use_natural_unit_ids:
        unit_ids = spike_channels['id']
        assert np.unique(unit_ids).size == unit_ids.size, 'unit_ids have duplications'
    else:
        # use integer-based unit_ids
        unit_ids = np.arange(spike_channels.size, dtype='int64')

    BaseSorting.__init__(self, sampling_frequency, unit_ids)

    nseg = self.neo_reader.segment_count(block_index=0)
    for segment_index in range(nseg):
        if self.handle_spike_frame_directly:
            t_start = None
        else:
            t_start = self.neo_reader.get_signal_t_start(0, segment_index)

        sorting_segment = NeoSortingSegment(self.neo_reader, segment_index,
                                            self.use_natural_unit_ids, t_start,
                                            sampling_frequency)
        self.add_sorting_segment(sorting_segment)
def __init__(self, file_path, load_unit_info=True):
    assert self.installed, self.installation_mesg

    self._recording_file = file_path
    self._rf = h5py.File(self._recording_file, mode='r')
    if 'Sampling' in self._rf:
        if self._rf['Sampling'][()] == 0:
            sampling_frequency = None
        else:
            sampling_frequency = self._rf['Sampling'][()]

    spike_ids = self._rf['cluster_id'][()]
    unit_ids = np.unique(spike_ids)
    spike_times = self._rf['times'][()]

    if load_unit_info:
        self.load_unit_info()

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids))

    self._kwargs = {'file_path': str(Path(file_path).absolute()),
                    'load_unit_info': load_unit_info}
def __init__(self, folder_path, chan_grp=None):
    try:
        import tridesclous as tdc
        HAVE_TDC = True
    except ImportError:
        HAVE_TDC = False
    assert HAVE_TDC, self.installation_mesg

    tdc_folder = Path(folder_path)
    dataio = tdc.DataIO(str(tdc_folder))

    if chan_grp is None:
        # if chan_grp is not provided, take the first one if unique
        chan_grps = list(dataio.channel_groups.keys())
        assert len(chan_grps) == 1, 'There are several groups in the folder, specify chan_grp=...'
        chan_grp = chan_grps[0]

    catalogue = dataio.load_catalogue(name='initial', chan_grp=chan_grp)

    labels = catalogue['clusters']['cluster_label']
    labels = labels[labels >= 0]
    unit_ids = list(labels)

    sampling_frequency = dataio.sample_rate

    BaseSorting.__init__(self, sampling_frequency, unit_ids)

    for seg_num in range(dataio.nb_segment):
        # load all spikes in memory (this avoids locking the folder with a memmap through dataio)
        all_spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp, i_start=None, i_stop=None).copy()
        self.add_sorting_segment(TridesclousSortingSegment(all_spikes))

    self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'chan_grp': chan_grp}
def __init__(self, oldapi_sorting_extractor):
    BaseSorting.__init__(self,
                         sampling_frequency=oldapi_sorting_extractor.get_sampling_frequency(),
                         unit_ids=oldapi_sorting_extractor.get_unit_ids())

    sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
    self.add_sorting_segment(sorting_segment)
def __init__(self, folder_path):
    assert HAVE_H5PY, self.installation_mesg

    spykingcircus_folder = Path(folder_path)
    listfiles = spykingcircus_folder.iterdir()

    parent_folder = None
    result_folder = None
    for f in listfiles:
        if f.is_dir():
            if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
                parent_folder = spykingcircus_folder
                result_folder = f

    if parent_folder is None:
        parent_folder = spykingcircus_folder.parent
        for f in parent_folder.iterdir():
            if f.is_dir():
                if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
                    result_folder = spykingcircus_folder

    assert isinstance(parent_folder, Path) and isinstance(result_folder, Path), \
        "Not a valid spyking circus folder"

    # load files
    results = None
    for f in result_folder.iterdir():
        if 'result.hdf5' in str(f):
            results = f
        if 'result-merged.hdf5' in str(f):
            results = f
            break

    if results is None:
        raise Exception(spykingcircus_folder, " is not a spyking circus folder")

    # load params
    sample_rate = None
    for f in parent_folder.iterdir():
        if f.suffix == '.params':
            sample_rate = _load_sample_rate(f)
    assert sample_rate is not None, 'sample rate not found'

    with h5py.File(results, 'r') as f_results:
        spiketrains = []
        unit_ids = []
        for temp in f_results['spiketimes'].keys():
            spiketrains.append(np.array(f_results['spiketimes'][temp]).astype('int64'))
            unit_ids.append(int(temp.split('_')[-1]))

    BaseSorting.__init__(self, sample_rate, unit_ids)
    self.add_sorting_segment(SpykingcircustSortingSegment(unit_ids, spiketrains))

    self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
def __init__(self, file_path, sampling_frequency):
    firings = readmda(str(file_path))
    labels = firings[2, :]
    unit_ids = np.unique(labels).astype(int)
    BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency)

    sorting_segment = MdaSortingSegment(firings)
    self.add_sorting_segment(sorting_segment)

    self._kwargs = {'file_path': str(Path(file_path).absolute()),
                    'sampling_frequency': sampling_frequency}
def __init__(self, oldapi_sorting_extractor):
    BaseSorting.__init__(self,
                         sampling_frequency=oldapi_sorting_extractor.get_sampling_frequency(),
                         unit_ids=oldapi_sorting_extractor.get_unit_ids())

    sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
    self.add_sorting_segment(sorting_segment)

    self.is_dumpable = False

    # add old properties
    copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)

    self._kwargs = {'oldapi_sorting_extractor': oldapi_sorting_extractor}
def __init__(self, folder_path, sampling_frequency=None, user='******', det_sign='both',
             keep_good_only=True):
    folder_path = Path(folder_path)
    assert folder_path.is_dir(), 'Folder {} doesn\'t exist'.format(folder_path)

    if sampling_frequency is None:
        h5_path = str(folder_path) + '.h5'
        if Path(h5_path).exists():
            with h5py.File(h5_path, mode='r') as f:
                sampling_frequency = f['sr'][0]

    # ~ self.set_sampling_frequency(sampling_frequency)
    det_file = str(folder_path / Path('data_' + folder_path.stem + '.h5'))
    sort_cat_files = []
    for sign in ['neg', 'pos']:
        if det_sign in ['both', sign]:
            sort_cat_file = folder_path / Path('sort_{}_{}/sort_cat.h5'.format(sign, user))
            if sort_cat_file.exists():
                sort_cat_files.append((sign, str(sort_cat_file)))

    unit_counter = 0
    spiketrains = {}
    metadata = {}
    unsorted = []
    with h5py.File(det_file, mode='r') as fdet:
        for sign, sfile in sort_cat_files:
            with h5py.File(sfile, mode='r') as f:
                sp_class = f['classes'][()]
                gaux = f['groups'][()]
                # array of classes per group
                groups = {g: gaux[gaux[:, 1] == g, 0] for g in np.unique(gaux[:, 1])}
                group_type = {group: g_type for group, g_type in f['types'][()]}
                sp_index = f['index'][()]

                times_css = fdet[sign]['times'][()]
                for gr, cls in groups.items():
                    if keep_good_only and (group_type[gr] < 1):  # artifact or unsorted
                        continue
                    spiketrains[unit_counter] = np.rint(
                        times_css[sp_index[np.isin(sp_class, cls)]] * (sampling_frequency / 1000))
                    metadata[unit_counter] = {'group_type': group_type[gr]}
                    unit_counter = unit_counter + 1

    unit_ids = np.arange(unit_counter, dtype='int64')
    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(CombinatoSortingSegment(spiketrains))
    self.set_property('unsorted', np.array([metadata[u]['group_type'] == 0 for u in range(unit_counter)]))
    self.set_property('artifact', np.array([metadata[u]['group_type'] == -1 for u in range(unit_counter)]))

    self._kwargs = {'folder_path': str(folder_path), 'user': user, 'det_sign': det_sign}
def __init__(self, sampling_frequency, multisortingcomparison, min_agreement_count=1,
             min_agreement_count_only=False):
    self._msc = multisortingcomparison
    self.is_dumpable = False

    # TODO: @alessio I leave this for you
    # if min_agreement_count_only:
    #     self._unit_ids = list(u for u in self._msc._new_units.keys()
    #                           if self._msc._new_units[u]['agreement_number'] == min_agreement_count)
    # else:
    #     self._unit_ids = list(u for u in self._msc._new_units.keys()
    #                           if self._msc._new_units[u]['agreement_number'] >= min_agreement_count)
    #
    # for unit in self._unit_ids:
    #     self.set_unit_property(unit_id=unit, property_name='agreement_number',
    #                            value=self._msc._new_units[unit]['agreement_number'])
    #     self.set_unit_property(unit_id=unit, property_name='avg_agreement',
    #                            value=self._msc._new_units[unit]['avg_agreement'])
    #     self.set_unit_property(unit_id=unit, property_name='sorter_unit_ids',
    #                            value=self._msc._new_units[unit]['sorter_unit_ids'])

    if min_agreement_count_only:
        unit_ids = list(u for u in self._msc._new_units.keys()
                        if self._msc._new_units[u]['agreement_number'] == min_agreement_count)
    else:
        unit_ids = list(u for u in self._msc._new_units.keys()
                        if self._msc._new_units[u]['agreement_number'] >= min_agreement_count)

    BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

    if len(unit_ids) > 0:
        for k in ('agreement_number', 'avg_agreement', 'sorter_unit_ids'):
            values = [self._msc._new_units[unit_id][k] for unit_id in unit_ids]
            self.set_property(k, values, ids=unit_ids)

    sorting_segment = AgreementSortingSegment(multisortingcomparison)
    self.add_sorting_segment(sorting_segment)
def __init__(self, file_path, keep_good_only=True):
    MatlabHelper.__init__(self, file_path)

    cluster_classes = self._getfield("cluster_class")
    classes = cluster_classes[:, 0]
    spike_times = cluster_classes[:, 1]
    par = self._getfield("par")
    sampling_frequency = par[0, 0][np.where(np.array(par.dtype.names) == 'sr')[0][0]][0][0]

    unit_ids = np.unique(classes).astype('int')
    if keep_good_only:
        unit_ids = unit_ids[unit_ids > 0]

    spiketrains = {}
    for unit_id in unit_ids:
        mask = (classes == unit_id)
        spiketrains[unit_id] = np.rint(spike_times[mask] * (sampling_frequency / 1000))

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(WaveClustSortingSegment(unit_ids, spiketrains))
    self.set_property('unsorted', np.array([c == 0 for c in unit_ids]))

    self._kwargs = {'file_path': str(Path(file_path).absolute())}
def __init__(self, folder_path):
    assert HAVE_YASS, self.installation_mesg

    folder_path = Path(folder_path)
    self.fname_spike_train = folder_path / 'tmp' / 'output' / 'spike_train.npy'
    self.fname_templates = folder_path / 'tmp' / 'output' / 'templates' / 'templates_0sec.npy'
    self.fname_config = folder_path / 'config.yaml'

    # Read CONFIG File
    with open(self.fname_config, 'r') as stream:
        self.config = yaml.safe_load(stream)

    spiketrains = np.load(self.fname_spike_train)
    unit_ids = np.unique(spiketrains[:, 1])

    # initialize
    sampling_frequency = self.config['recordings']['sampling_rate']
    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(YassSortingSegment(spiketrains))

    self._kwargs = {'folder_path': str(folder_path)}
def __init__(self, sampling_frequency, multisortingcomparison, min_agreement_count=1,
             min_agreement_count_only=False):
    self._msc = multisortingcomparison
    self.is_dumpable = False

    if min_agreement_count_only:
        unit_ids = list(u for u in self._msc._new_units.keys()
                        if self._msc._new_units[u]['agreement_number'] == min_agreement_count)
    else:
        unit_ids = list(u for u in self._msc._new_units.keys()
                        if self._msc._new_units[u]['agreement_number'] >= min_agreement_count)

    BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

    if len(unit_ids) > 0:
        for k in ('agreement_number', 'avg_agreement', 'unit_ids'):
            values = [self._msc._new_units[unit_id][k] for unit_id in unit_ids]
            self.set_property(k, values, ids=unit_ids)

    sorting_segment = AgreementSortingSegment(multisortingcomparison)
    self.add_sorting_segment(sorting_segment)
def __init__(self, file_path, sampling_frequency, delimiter=','):
    assert self.installed, self.installation_mesg
    assert Path(file_path).suffix == ".csv", "The 'file_path' should be a csv file!"

    if Path(file_path).is_file():
        spike_clusters = sbio.SpikeClusters()
        spike_clusters.fromCSV(str(file_path), None, delimiter=delimiter)
    else:
        raise FileNotFoundError(f'The ground truth file {file_path} could not be found')

    BaseSorting.__init__(self, unit_ids=spike_clusters.keys(), sampling_frequency=sampling_frequency)

    sorting_segment = SHYBRIDSortingSegment(spike_clusters)
    self.add_sorting_segment(sorting_segment)

    self._kwargs = {'file_path': str(Path(file_path).absolute()),
                    'sampling_frequency': sampling_frequency,
                    'delimiter': delimiter}
def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
    assert HAVE_H5PY, self.installation_mesg
    # ~ SortingExtractor.__init__(self)

    kwik_file_or_folder = Path(file_or_folder_path)
    kwikfile = None
    klustafolder = None
    if kwik_file_or_folder.is_file():
        assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
        kwikfile = Path(kwik_file_or_folder).absolute()
        klustafolder = kwikfile.parent
    elif kwik_file_or_folder.is_dir():
        klustafolder = kwik_file_or_folder
        kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik']
        if len(kwikfiles) == 1:
            kwikfile = kwikfiles[0]
    assert kwikfile is not None, "Could not load '.kwik' file"

    try:
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
    except Exception as e:
        print("Could not load sampling frequency info")

    kf_reader = h5py.File(kwikfile, 'r')
    spiketrains = []
    unit_ids = []
    unique_units = []
    klusta_units = []
    cluster_groups_name = []
    groups = []
    unit = 0

    cs_to_exclude = []
    valid_group_names = [i[1].lower() for i in self.default_cluster_groups.items()]
    if exclude_cluster_groups is not None:
        assert isinstance(exclude_cluster_groups, list), 'exclude_cluster_groups should be a list'
        for ec in exclude_cluster_groups:
            assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'
            cs_to_exclude.append(ec.lower())

    for channel_group in kf_reader.get('/channel_groups'):
        chan_cluster_id_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
        chan_cluster_times_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/time_samples')[()]
        # if clusters were merged in the gui, the original ids are still in the
        # kwik tree, but not in this array
        chan_cluster_ids = np.unique(chan_cluster_id_arr)

        for cluster_id in chan_cluster_ids:
            cluster_frame_idx = np.nonzero(chan_cluster_id_arr == cluster_id)  # the [()] is a h5py thing
            st = chan_cluster_times_arr[cluster_frame_idx]
            assert st.shape[0] > 0, 'no spikes in cluster'
            cluster_group = kf_reader.get(
                f'/channel_groups/{channel_group}/clusters/main/{cluster_id}').attrs['cluster_group']

            assert cluster_group in self.default_cluster_groups.keys(), \
                f'cluster_group not in default_cluster_groups: {cluster_group}'
            cluster_group_name = self.default_cluster_groups[cluster_group]

            if cluster_group_name.lower() in cs_to_exclude:
                continue

            spiketrains.append(st)
            klusta_units.append(int(cluster_id))
            unique_units.append(unit)
            unit += 1
            groups.append(int(channel_group))
            cluster_groups_name.append(cluster_group_name)

    if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
        unit_ids = klusta_units
    else:
        print('Klusta units are not unique! Using unique unit ids')
        unit_ids = unique_units

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.is_dumpable = False

    self.add_sorting_segment(KlustSortingSegment(unit_ids, spiketrains))

    self.set_property('group', groups)
    quality = [e.lower() for e in cluster_groups_name]
    self.set_property('quality', quality)

    self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute())}
def __init__(self, file_path, electrical_series_name: str = None, sampling_frequency: float = None):
    """
    Parameters
    ----------
    file_path: path to NWB file
    electrical_series_name: str with pynwb.ecephys.ElectricalSeries object name
    sampling_frequency: float
    """
    assert self.installed, self.installation_mesg
    self._file_path = str(file_path)

    with NWBHDF5IO(self._file_path, 'r') as io:
        nwbfile = io.read()
        if sampling_frequency is None:
            # defines the electrical series the sorting came from;
            # important to know the sampling_frequency
            if electrical_series_name is None:
                if len(nwbfile.acquisition) > 1:
                    raise Exception('More than one acquisition found. You must specify electrical_series_name.')
                if len(nwbfile.acquisition) == 0:
                    raise Exception("No acquisitions found in the .nwb file from which to read sampling frequency. "
                                    "Please specify the 'sampling_frequency' parameter.")
                es = list(nwbfile.acquisition.values())[0]
            else:
                # look up the acquisition by its name
                es = nwbfile.acquisition[electrical_series_name]
            # get rate
            if es.rate is not None:
                sampling_frequency = es.rate
            else:
                sampling_frequency = 1 / (es.timestamps[1] - es.timestamps[0])

        assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                               "'sampling_frequency' argument"

        # get all unit ids
        units_ids = list(nwbfile.units.id[:])

        # store unit properties and spike features in dictionaries
        all_pr_ft = list(nwbfile.units.colnames)
        all_names = [i.name for i in nwbfile.units.columns]
        properties = dict()
        for item in all_pr_ft:
            if item == 'spike_times':
                continue
            # test if item is a unit_property or a spike_feature
            if item + '_index' in all_names:
                # if it has an index, it is a spike_feature
                pass
            else:
                # it is a unit_property
                for u_id in units_ids:
                    ind = list(units_ids).index(u_id)
                    if isinstance(nwbfile.units[item][ind], pd.DataFrame):
                        prop_value = nwbfile.units[item][ind].index[0]
                    else:
                        prop_value = nwbfile.units[item][ind]
                    if item not in properties:
                        properties[item] = np.zeros(len(units_ids), dtype=type(prop_value))
                    properties[item][ind] = prop_value

    BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids)
    sorting_segment = NwbSortingSegment(path=self._file_path, sampling_frequency=sampling_frequency)
    self.add_sorting_segment(sorting_segment)

    for prop_name, values in properties.items():
        self.set_property(prop_name, values)

    self._kwargs = {'file_path': str(Path(file_path).absolute()),
                    'electrical_series_name': electrical_series_name,
                    'sampling_frequency': sampling_frequency}
def __init__(self, sampling_frequency, unit_ids=[]):
    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.is_dumpable = False
def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=False):
    try:
        import pandas as pd
        HAVE_PD = True
    except ImportError:
        HAVE_PD = False
    assert HAVE_PD, self.installation_mesg

    phy_folder = Path(folder_path)

    spike_times = np.load(phy_folder / 'spike_times.npy')
    spike_templates = np.load(phy_folder / 'spike_templates.npy')

    if (phy_folder / 'spike_clusters.npy').is_file():
        spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
    else:
        spike_clusters = spike_templates

    clust_id = np.unique(spike_clusters)
    unit_ids = list(clust_id)
    spike_times.astype(int)
    params = read_python(str(phy_folder / 'params.py'))
    sampling_frequency = params['sample_rate']

    # try to load cluster info
    cluster_info_files = [p for p in phy_folder.iterdir()
                          if p.suffix in ['.csv', '.tsv'] and "cluster_info" in p.name]

    if len(cluster_info_files) == 1:
        # load properties from cluster_info file
        cluster_info_file = cluster_info_files[0]
        if cluster_info_file.suffix == ".tsv":
            delimeter = "\t"
        else:
            delimeter = ","
        cluster_info = pd.read_csv(cluster_info_file, delimiter=delimeter)
    else:
        # load properties from other tsv/csv files
        all_property_files = [p for p in phy_folder.iterdir() if p.suffix in ['.csv', '.tsv']]

        cluster_info = None
        for file in all_property_files:
            if file.suffix == ".tsv":
                delimeter = "\t"
            else:
                delimeter = ","
            new_property = pd.read_csv(file, delimiter=delimeter)
            if cluster_info is None:
                cluster_info = new_property
            else:
                cluster_info = pd.merge(cluster_info, new_property, on='cluster_id')

    # in case no tsv/csv files are found, populate cluster info with minimal info
    if cluster_info is None:
        cluster_info = pd.DataFrame({'cluster_id': unit_ids})
        cluster_info['group'] = ['unsorted'] * len(unit_ids)

    if exclude_cluster_groups is not None:
        if isinstance(exclude_cluster_groups, str):
            cluster_info = cluster_info.query(f"group != '{exclude_cluster_groups}'")
        elif isinstance(exclude_cluster_groups, list):
            if len(exclude_cluster_groups) > 0:
                for exclude_group in exclude_cluster_groups:
                    cluster_info = cluster_info.query(f"group != '{exclude_group}'")

    if keep_good_only and "KSLabel" in cluster_info.columns:
        cluster_info = cluster_info.query("KSLabel == 'good'")

    if "cluster_id" not in cluster_info.columns:
        assert "id" in cluster_info.columns, "Couldn't find cluster ids in the tsv files!"
        cluster_info["cluster_id"] = cluster_info["id"]
        del cluster_info["id"]

    if 'si_unit_id' in cluster_info.columns:
        unit_ids = cluster_info["si_unit_id"].values
        del cluster_info["si_unit_id"]
    else:
        unit_ids = cluster_info["cluster_id"].values

    BaseSorting.__init__(self, sampling_frequency, unit_ids)

    del cluster_info["cluster_id"]
    for prop_name in cluster_info.columns:
        if prop_name in ['chan_grp', 'ch_group']:
            self.set_property(key="group", values=cluster_info[prop_name])
        elif prop_name != "group":
            self.set_property(key=prop_name, values=cluster_info[prop_name])
        elif prop_name == "group":
            # rename the 'group' property to 'quality'
            self.set_property(key="quality", values=cluster_info[prop_name])

    self.add_sorting_segment(PhySortingSegment(spike_times, spike_clusters))
def __init__(self, file_path: PathType, electrical_series_name: str = None,
             sampling_frequency: float = None, samples_for_rate_estimation: int = 100000):
    check_nwb_install()
    self._file_path = str(file_path)
    self._electrical_series_name = electrical_series_name

    io = NWBHDF5IO(self._file_path, mode='r', load_namespaces=True)
    self._nwbfile = io.read()
    timestamps = None
    if sampling_frequency is None:
        # defines the electrical series the sorting came from;
        # important to know the sampling_frequency
        self._es = get_electrical_series(self._nwbfile, self._electrical_series_name)
        # get rate
        if self._es.rate is not None:
            sampling_frequency = self._es.rate
        else:
            if hasattr(self._es, "timestamps"):
                if self._es.timestamps is not None:
                    timestamps = self._es.timestamps
                    sampling_frequency = 1 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))

    assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                           "'sampling_frequency' argument"

    # get all unit ids
    units_ids = list(self._nwbfile.units.id[:])

    # store unit properties and spike features in dictionaries
    properties = dict()
    for column in list(self._nwbfile.units.colnames):
        if column == 'spike_times':
            continue
        # it is a unit_property
        property_values = self._nwbfile.units[column][:]
        # only load columns with the same shape for all units
        if all(p.shape == property_values[0].shape for p in property_values):
            properties[column] = property_values
        else:
            print(f"Skipping {column} because of unequal shapes across units")

    BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=units_ids)
    sorting_segment = NwbSortingSegment(nwbfile=self._nwbfile,
                                        sampling_frequency=sampling_frequency,
                                        timestamps=timestamps)
    self.add_sorting_segment(sorting_segment)

    for prop_name, values in properties.items():
        self.set_property(prop_name, np.array(values))

    self._kwargs = {'file_path': str(Path(file_path).absolute()),
                    'electrical_series_name': self._electrical_series_name,
                    'sampling_frequency': sampling_frequency,
                    'samples_for_rate_estimation': samples_for_rate_estimation}
def __init__(self, folder_path, sampling_frequency=30000):
    assert self.installed, self.installation_mesg

    # check correct parent folder:
    self._folder_path = Path(folder_path)
    if 'probe' not in self._folder_path.name:
        raise ValueError('folder name should contain "probe", containing channels, clusters.* .npy datasets')

    # load datasets as mmap into a dict:
    required_alf_datasets = ['spikes.times', 'spikes.clusters']
    found_alf_datasets = dict()
    for alf_dataset_name in self._folder_path.iterdir():
        if 'spikes' in alf_dataset_name.stem or 'clusters' in alf_dataset_name.stem:
            if 'npy' in alf_dataset_name.suffix:
                dset = np.load(alf_dataset_name, mmap_mode='r', allow_pickle=True)
                found_alf_datasets.update({alf_dataset_name.stem: dset})
            elif 'metrics' in alf_dataset_name.stem:
                found_alf_datasets.update({alf_dataset_name.stem: pd.read_csv(alf_dataset_name)})

    # check existence of datasets (both required datasets must be present):
    if not all(i in found_alf_datasets for i in required_alf_datasets):
        raise Exception(f'could not find {required_alf_datasets} in folder')

    spike_clusters = found_alf_datasets['spikes.clusters']
    spike_times = found_alf_datasets['spikes.times']

    # load unit properties:
    total_units = 0
    properties = dict()
    for alf_dataset_name, alf_dataset in found_alf_datasets.items():
        if 'clusters' in alf_dataset_name:
            if 'clusters.metrics' in alf_dataset_name:
                for property_name, property_values in found_alf_datasets[alf_dataset_name].items():
                    properties[property_name] = property_values.tolist()
            else:
                property_name = alf_dataset_name.split('.')[1]
                properties[property_name] = alf_dataset
                if total_units == 0:
                    total_units = alf_dataset.shape[0]

    if 'clusters.metrics' in found_alf_datasets and \
            found_alf_datasets['clusters.metrics'].get('cluster_id') is not None:
        unit_ids = found_alf_datasets['clusters.metrics'].get('cluster_id').tolist()
    else:
        unit_ids = list(range(total_units))

    BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency)
    sorting_segment = ALFSortingSegment(spike_clusters, spike_times, sampling_frequency)
    self.add_sorting_segment(sorting_segment)

    # add properties
    for property_name, values in properties.items():
        self.set_property(property_name, values)

    self._kwargs = {'folder_path': str(Path(folder_path).absolute()),
                    'sampling_frequency': sampling_frequency}
def __init__(self, file_path, keep_good_only=True):
    MatlabHelper.__init__(self, file_path)

    if not self._old_style_mat:
        _units = self._data['Units']
        units = _parse_units(self._data, _units)

        # Extracting MultiElectrode field by field:
        _ME = self._data["MultiElectrode"]
        multi_electrode = dict((k, _ME.get(k)[()]) for k in _ME.keys())

        # Extracting sampling_frequency:
        sr = self._data["samplingRate"]
        sampling_frequency = float(_squeeze_ds(sr))

        # Remove noise units if necessary:
        if keep_good_only:
            units = [unit for unit in units if unit["ID"].flatten()[0].astype(int) % 1000 != 0]

        if 'sortingInfo' in self._data.keys():
            info = self._data["sortingInfo"]
            start_frame = _squeeze_ds(info['startTimes'])
            self.start_frame = int(start_frame)
        else:
            self.start_frame = 0
    else:
        _units = self._getfield('Units').squeeze()
        fields = _units.dtype.fields.keys()
        units = []
        for unit in _units:
            unit_dict = {}
            for f in fields:
                unit_dict[f] = unit[f]
            units.append(unit_dict)

        sr = self._getfield("samplingRate")
        sampling_frequency = float(_squeeze_ds(sr))

        _ME = self._data["MultiElectrode"]
        multi_electrode = dict((k, _ME[k][0][0].T) for k in _ME.dtype.fields.keys())

        # Remove noise units if necessary:
        if keep_good_only:
            units = [unit for unit in units if unit["ID"].flatten()[0].astype(int) % 1000 != 0]

        if 'sortingInfo' in self._data.keys():
            info = self._getfield("sortingInfo")
            start_frame = _squeeze_ds(info['startTimes'])
            self.start_frame = int(start_frame)
        else:
            self.start_frame = 0

    self._units = units
    self._multi_electrode = multi_electrode

    unit_ids = []
    spiketrains = []
    for uc, unit in enumerate(units):
        unit_id = int(_squeeze_ds(unit["ID"]))
        spike_times = _squeeze(unit["spikeTrain"]).astype('int64') - self.start_frame
        unit_ids.append(unit_id)
        spiketrains.append(spike_times)

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(HDSortSortingSegment(unit_ids, spiketrains))

    # property
    templates = []
    templates_frames_cut_before = []
    for uc, unit in enumerate(units):
        if self._old_style_mat:
            template = unit["footprint"].T
        else:
            template = unit["footprint"]
        templates.append(template)
        templates_frames_cut_before.append(unit["cutLeft"].flatten())

    self.set_property("template", np.array(templates))
    self.set_property("template_frames_cut_before", np.array(templates_frames_cut_before))

    self._kwargs = {'file_path': str(file_path), 'keep_good_only': keep_good_only}
def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=False):
    try:
        import pandas as pd
        HAVE_PD = True
    except ImportError:
        HAVE_PD = False
    assert HAVE_PD, self.installation_mesg

    phy_folder = Path(folder_path)

    spike_times = np.load(phy_folder / 'spike_times.npy')
    spike_templates = np.load(phy_folder / 'spike_templates.npy')

    if (phy_folder / 'spike_clusters.npy').is_file():
        spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
    else:
        spike_clusters = spike_templates

    clust_id = np.unique(spike_clusters)
    unit_ids = list(clust_id)
    spike_times.astype(int)
    params = read_python(str(phy_folder / 'params.py'))
    sampling_frequency = params['sample_rate']

    # try to load cluster info
    cluster_info_files = [p for p in phy_folder.iterdir()
                          if p.suffix in ['.csv', '.tsv'] and "cluster_info" in p.name]

    if len(cluster_info_files) == 1:
        cluster_info_file = cluster_info_files[0]
        if cluster_info_file.suffix == ".tsv":
            delimeter = "\t"
        else:
            delimeter = ","
        cluster_info = pd.read_csv(cluster_info_file, delimiter=delimeter)
    else:
        all_property_files = [p for p in phy_folder.iterdir() if p.suffix in ['.csv', '.tsv']]

        cluster_info = None
        for file in all_property_files:
            if file.suffix == ".tsv":
                delimeter = "\t"
            else:
                delimeter = ","
            new_property = pd.read_csv(file, delimiter=delimeter)
            if cluster_info is None:
                cluster_info = new_property
            else:
                cluster_info = pd.merge(cluster_info, new_property, on='cluster_id')
        cluster_info["id"] = cluster_info["cluster_id"]
        del cluster_info["cluster_id"]

    if exclude_cluster_groups is not None:
        if isinstance(exclude_cluster_groups, str):
            cluster_info = cluster_info.query(f"group != '{exclude_cluster_groups}'")
        elif isinstance(exclude_cluster_groups, list):
            if len(exclude_cluster_groups) > 0:
                for exclude_group in exclude_cluster_groups:
                    cluster_info = cluster_info.query(f"group != '{exclude_group}'")

    if keep_good_only and "KSLabel" in cluster_info.columns:
        cluster_info = cluster_info.query("KSLabel == 'good'")

    unit_ids = cluster_info["id"].values
    BaseSorting.__init__(self, sampling_frequency, unit_ids)

    for prop_name in cluster_info.columns:
        if prop_name in ['chan_grp', 'ch_group']:
            self.set_property(key="group", values=cluster_info[prop_name])
        elif prop_name != "group":
            self.set_property(key=prop_name, values=cluster_info[prop_name])

    self.add_sorting_segment(PhySortingSegment(spike_times, spike_clusters))

    self._kwargs = {'folder_path': str(Path(folder_path).absolute()),
                    'exclude_cluster_groups': exclude_cluster_groups}