def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=False):
    """
    Build the sorting from a folder of Phy-style output files.

    Parameters
    ----------
    folder_path: str or Path
        Folder holding 'spike_times.npy', 'spike_templates.npy' and 'params.py'.
        'spike_clusters.npy' and '.csv'/'.tsv' cluster tables are used when present.
    exclude_cluster_groups: str, list of str or None
        Value(s) of the 'group' column; clusters carrying one of them are dropped.
    keep_good_only: bool
        When True and a 'KSLabel' column exists, keep only clusters labelled 'good'.
    """
    try:
        import pandas as pd
        HAVE_PD = True
    except ImportError:
        HAVE_PD = False
    assert HAVE_PD, self.installation_mesg

    phy_folder = Path(folder_path)

    spike_times = np.load(phy_folder / 'spike_times.npy')
    spike_templates = np.load(phy_folder / 'spike_templates.npy')

    # without a curated clustering, the template assignment is the clustering
    if (phy_folder / 'spike_clusters.npy').is_file():
        spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
    else:
        spike_clusters = spike_templates

    unit_ids = list(np.unique(spike_clusters))
    # astype returns a new array; the result has to be kept for the cast to take effect
    spike_times = spike_times.astype(int)

    params = read_python(str(phy_folder / 'params.py'))
    sampling_frequency = params['sample_rate']

    def _read_table(table_path):
        # '.tsv' files are tab separated, everything else is treated as comma separated
        separator = "\t" if table_path.suffix == ".tsv" else ","
        return pd.read_csv(table_path, delimiter=separator)

    table_files = [p for p in phy_folder.iterdir() if p.suffix in ['.csv', '.tsv']]
    cluster_info_files = [p for p in table_files if "cluster_info" in p.name]

    if len(cluster_info_files) == 1:
        # a single cluster_info file already gathers every property
        cluster_info = _read_table(cluster_info_files[0])
    else:
        # otherwise merge all the individual property tables on the cluster id
        cluster_info = None
        for table_file in table_files:
            new_property = _read_table(table_file)
            if cluster_info is None:
                cluster_info = new_property
            else:
                cluster_info = pd.merge(cluster_info, new_property, on='cluster_id')

    # in case no tsv/csv files are found populate cluster info with minimal info
    if cluster_info is None:
        cluster_info = pd.DataFrame({'cluster_id': unit_ids})
        cluster_info['group'] = ['unsorted'] * len(unit_ids)

    if isinstance(exclude_cluster_groups, str):
        excluded_groups = [exclude_cluster_groups]
    elif isinstance(exclude_cluster_groups, list):
        excluded_groups = exclude_cluster_groups
    else:
        # None (or any other type) means no group-based exclusion
        excluded_groups = []
    if excluded_groups:
        # boolean mask instead of a string-built query: group names with quotes cannot break it
        cluster_info = cluster_info[~cluster_info['group'].isin(excluded_groups)]

    if keep_good_only and "KSLabel" in cluster_info.columns:
        # keep (not drop) the units labelled 'good'
        cluster_info = cluster_info[cluster_info['KSLabel'] == 'good']

    unit_ids = cluster_info["cluster_id"].values
    BaseSorting.__init__(self, sampling_frequency, unit_ids)

    for prop_name in cluster_info.columns:
        if prop_name == "cluster_id":
            # already used as the unit ids
            continue
        if prop_name in ('chan_grp', 'ch_group'):
            key = "group"
        elif prop_name == "group":
            # rename group property to 'quality'
            key = "quality"
        else:
            key = prop_name
        self.set_property(key=key, values=cluster_info[prop_name])

    self.add_sorting_segment(PhySortingSegment(spike_times, spike_clusters))
def __init__(self, folder_path, sampling_frequency=None, user='******', det_sign='both'):
    """
    Build the sorting from a Combinato output folder.

    Parameters
    ----------
    folder_path: str or Path
        Existing folder containing 'data_<folder stem>.h5' and the
        'sort_<sign>_<user>/sort_cat.h5' results.
    sampling_frequency: float or None
        If None, it is read from the 'sr' dataset of the '<folder_path>.h5' file
        next to the folder, when that file exists.
    user: str
        Suffix of the 'sort_<sign>_<user>' sub-folders to load.
    det_sign: str
        'neg', 'pos' or 'both': which detection sign(s) to load.
    """
    folder_path = Path(folder_path)
    assert folder_path.is_dir(), 'Folder {} doesn\'t exist'.format(
        folder_path)

    # remembered untouched so that re-instantiation from _kwargs behaves like this call
    requested_sampling_frequency = sampling_frequency
    if sampling_frequency is None:
        h5_path = str(folder_path) + '.h5'
        if Path(h5_path).exists():
            with h5py.File(h5_path, mode='r') as f:
                sampling_frequency = f['sr'][0]
    # fail early with a clear message instead of a 'NoneType / int' error further down
    assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                           "'sampling_frequency' argument"

    det_file = str(folder_path / Path('data_' + folder_path.stem + '.h5'))
    sort_cat_files = []
    for sign in ['neg', 'pos']:
        if det_sign in ['both', sign]:
            sort_cat_file = folder_path / Path(
                'sort_{}_{}/sort_cat.h5'.format(sign, user))
            if sort_cat_file.exists():
                sort_cat_files.append((sign, str(sort_cat_file)))

    unit_counter = 0
    spiketrains = {}
    metadata = {}
    with h5py.File(det_file, mode='r') as fdet:
        for sign, sfile in sort_cat_files:
            with h5py.File(sfile, mode='r') as f:
                sp_class = f['classes'][()]
                gaux = f['groups'][()]
                # array of classes per group
                groups = {
                    g: gaux[gaux[:, 1] == g, 0]
                    for g in np.unique(gaux[:, 1])
                }
                group_type = {
                    group: g_type
                    for group, g_type in f['types'][()]
                }
                sp_index = f['index'][()]

            times_css = fdet[sign]['times'][()]
            for gr, cls in groups.items():
                # times are scaled by sampling_frequency / 1000 and rounded to frames
                spiketrains[unit_counter] = np.rint(
                    times_css[sp_index[np.isin(sp_class, cls)]] *
                    (sampling_frequency / 1000))
                metadata[unit_counter] = {'group_type': group_type[gr]}
                unit_counter = unit_counter + 1

    unit_ids = np.arange(unit_counter, dtype='int64')
    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(CombinatoSortingSegment(spiketrains))

    # group type 0 is exposed as 'unsorted', group type -1 as 'artifact'
    self.set_property(
        'unsorted',
        np.array(
            [metadata[u]['group_type'] == 0 for u in range(unit_counter)]))
    self.set_property(
        'artifact',
        np.array([
            metadata[u]['group_type'] == -1 for u in range(unit_counter)
        ]))
    self._kwargs = {
        'folder_path': str(folder_path),
        'sampling_frequency': requested_sampling_frequency,
        'user': user,
        'det_sign': det_sign
    }
def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
    """
    Build the sorting from a '.kwik' file or from a folder containing exactly one.

    Parameters
    ----------
    file_or_folder_path: str or Path
        Path to a '.kwik' file, or to a folder holding a single '.kwik' file.
        A '.prm' file in the same folder provides the sampling frequency.
    exclude_cluster_groups: list of str or None
        Cluster group names (values of `self.default_cluster_groups`, case-insensitive)
        whose clusters are skipped.

    Raises
    ------
    ValueError
        If the sampling frequency cannot be read from the '.prm' file.
    """
    assert HAVE_H5PY, self.installation_mesg

    kwik_file_or_folder = Path(file_or_folder_path)
    kwikfile = None
    klustafolder = None
    if kwik_file_or_folder.is_file():
        assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
        kwikfile = Path(kwik_file_or_folder).absolute()
        klustafolder = kwikfile.parent
    elif kwik_file_or_folder.is_dir():
        klustafolder = kwik_file_or_folder
        kwikfiles = [
            f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik'
        ]
        if len(kwikfiles) == 1:
            kwikfile = kwikfiles[0]
    assert kwikfile is not None, "Could not load '.kwik' file"

    try:
        config_file = [
            f for f in klustafolder.iterdir() if f.suffix == '.prm'
        ][0]
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
    except Exception as e:
        # the sampling frequency is mandatory below: surface the real cause
        # instead of continuing with an unbound name
        raise ValueError("Could not load sampling frequency info") from e

    valid_group_names = [
        name.lower() for name in self.default_cluster_groups.values()
    ]
    cs_to_exclude = []
    if exclude_cluster_groups is not None:
        assert isinstance(exclude_cluster_groups,
                          list), 'exclude_cluster_groups should be a list'
        for ec in exclude_cluster_groups:
            ec_lower = ec.lower()
            assert ec_lower in valid_group_names, f'select exclude names out of: {valid_group_names}'
            cs_to_exclude.append(ec_lower)

    spiketrains = []
    unique_units = []
    klusta_units = []
    cluster_groups_name = []
    groups = []
    unit = 0

    # everything needed is copied to memory inside the block, so the file can be closed after it
    with h5py.File(kwikfile, 'r') as kf_reader:
        for channel_group in kf_reader.get('/channel_groups'):
            # the [()] reads the whole h5py dataset into a numpy array
            chan_cluster_id_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
            chan_cluster_times_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/time_samples')[()]
            # ids come from the per-spike array rather than from the cluster tree:
            # presumably merged clusters remain in the tree but no longer here — TODO confirm
            chan_cluster_ids = np.unique(chan_cluster_id_arr)

            for cluster_id in chan_cluster_ids:
                cluster_frame_idx = np.nonzero(
                    chan_cluster_id_arr == cluster_id)
                st = chan_cluster_times_arr[cluster_frame_idx]
                assert st.shape[0] > 0, 'no spikes in cluster'
                cluster_group = kf_reader.get(
                    f'/channel_groups/{channel_group}/clusters/main/{cluster_id}'
                ).attrs['cluster_group']

                assert cluster_group in self.default_cluster_groups.keys(
                ), f'cluster_group not in "default_dict: {cluster_group}'
                cluster_group_name = self.default_cluster_groups[cluster_group]

                if cluster_group_name.lower() in cs_to_exclude:
                    continue

                spiketrains.append(st)
                klusta_units.append(int(cluster_id))
                unique_units.append(unit)
                unit += 1
                groups.append(int(channel_group))
                cluster_groups_name.append(cluster_group_name)

    # cluster ids can repeat across channel groups; fall back to a running index then
    if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
        unit_ids = klusta_units
    else:
        print('Klusta units are not unique! Using unique unit ids')
        unit_ids = unique_units

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.is_dumpable = False

    self.add_sorting_segment(KlustSortingSegment(unit_ids, spiketrains))

    self.set_property('group', groups)
    quality = [e.lower() for e in cluster_groups_name]
    self.set_property('quality', quality)

    self._kwargs = {
        'file_or_folder_path': str(Path(file_or_folder_path).absolute()),
        'exclude_cluster_groups': exclude_cluster_groups
    }
def __init__(self, sampling_frequency, unit_ids=None):
    """
    Initialise the base sorting and mark the object as not dumpable.

    Parameters
    ----------
    sampling_frequency: float
        Forwarded to `BaseSorting.__init__`.
    unit_ids: list or None
        Forwarded to `BaseSorting.__init__`; None (the default) means no units.
    """
    # a fresh list per call: a literal [] default would be shared by every instance
    if unit_ids is None:
        unit_ids = []
    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.is_dumpable = False
def __init__(self, file_path, keep_good_only=True):
    """
    Build the sorting from an HDSort '.mat' result file.

    Parameters
    ----------
    file_path: str or Path
        File handed to `MatlabHelper.__init__`, which decides between the
        old-style and the newer reading path (`self._old_style_mat`).
    keep_good_only: bool
        If True, units whose ID is a multiple of 1000 (noise units) are removed.
    """
    MatlabHelper.__init__(self, file_path)

    if self._old_style_mat:
        raw_units = self._getfield('Units').squeeze()
        field_names = raw_units.dtype.fields.keys()
        # turn each structured record into a plain dict, field by field
        units = [
            {name: raw_unit[name] for name in field_names}
            for raw_unit in raw_units
        ]

        sampling_frequency = float(_squeeze_ds(self._getfield("samplingRate")))

        electrode_struct = self._data["MultiElectrode"]
        multi_electrode = {
            key: electrode_struct[key][0][0].T
            for key in electrode_struct.dtype.fields.keys()
        }

        # Remove noise units if necessary:
        if keep_good_only:
            units = [
                u for u in units
                if u["ID"].flatten()[0].astype(int) % 1000 != 0
            ]

        if 'sortingInfo' in self._data.keys():
            sorting_info = self._getfield("sortingInfo")
            self.start_frame = int(_squeeze_ds(sorting_info['startTimes']))
        else:
            self.start_frame = 0
    else:
        units = _parse_units(self._data, self._data['Units'])

        # Extracting MultiElectrode field by field:
        electrode_group = self._data["MultiElectrode"]
        multi_electrode = {
            key: electrode_group.get(key)[()]
            for key in electrode_group.keys()
        }

        # Extracting sampling_frequency:
        sampling_frequency = float(_squeeze_ds(self._data["samplingRate"]))

        # Remove noise units if necessary:
        if keep_good_only:
            units = [
                u for u in units
                if u["ID"].flatten()[0].astype(int) % 1000 != 0
            ]

        if 'sortingInfo' in self._data.keys():
            sorting_info = self._data["sortingInfo"]
            self.start_frame = int(_squeeze_ds(sorting_info['startTimes']))
        else:
            self.start_frame = 0

    self._units = units
    self._multi_electrode = multi_electrode

    unit_ids = []
    spiketrains = []
    for unit in units:
        unit_ids.append(int(_squeeze_ds(unit["ID"])))
        # spike trains are expressed relative to the start frame
        spiketrains.append(
            _squeeze(unit["spikeTrain"]).astype('int64') - self.start_frame)

    BaseSorting.__init__(self, sampling_frequency, unit_ids)
    self.add_sorting_segment(HDSortSortingSegment(unit_ids, spiketrains))

    # property
    templates = []
    templates_frames_cut_before = []
    for unit in units:
        footprint = unit["footprint"]
        # old-style files store the footprint transposed
        templates.append(footprint.T if self._old_style_mat else footprint)
        templates_frames_cut_before.append(unit["cutLeft"].flatten())

    self.set_property("template", np.array(templates))
    self.set_property("template_frames_cut_before",
                      np.array(templates_frames_cut_before))

    self._kwargs = {
        'file_path': str(file_path),
        'keep_good_only': keep_good_only
    }
def __init__(self, folder_path, sampling_frequency=30000):
    """
    Build the sorting from a folder of ALF datasets.

    Parameters
    ----------
    folder_path: str or Path
        Folder whose name contains "probe", holding at least
        'spikes.times.npy' and 'spikes.clusters.npy'. 'clusters.*.npy' files and a
        'clusters.metrics' table are loaded as unit properties.
    sampling_frequency: float
        Sampling frequency passed to the base class and to the segment.

    Raises
    ------
    ValueError
        If the folder name does not contain "probe".
    Exception
        If one of the required datasets is missing.
    """
    assert self.installed, self.installation_mesg
    # check correct parent folder:
    self._folder_path = Path(folder_path)
    if 'probe' not in self._folder_path.name:
        raise ValueError(
            'folder name should contain "probe", containing channels, clusters.* .npy datasets'
        )

    # load datasets as mmap into a dict:
    required_alf_datasets = ['spikes.times', 'spikes.clusters']
    found_alf_datasets = dict()
    # iterate the folder stored just above (this method sets no other path attribute)
    for alf_dataset_path in self._folder_path.iterdir():
        stem = alf_dataset_path.stem
        if 'spikes' in stem or 'clusters' in stem:
            if 'npy' in alf_dataset_path.suffix:
                # NOTE(review): allow_pickle=True unpickles file content; only use on trusted folders
                found_alf_datasets[stem] = np.load(alf_dataset_path,
                                                   mmap_mode='r',
                                                   allow_pickle=True)
            elif 'metrics' in stem:
                found_alf_datasets[stem] = pd.read_csv(alf_dataset_path)

    # check existence of datasets: every required one must be there
    if not all(name in found_alf_datasets for name in required_alf_datasets):
        raise Exception(
            f'could not find {required_alf_datasets} in folder')

    spike_clusters = found_alf_datasets['spikes.clusters']
    spike_times = found_alf_datasets['spikes.times']

    # load units properties:
    total_units = 0
    properties = dict()
    for alf_dataset_name, alf_dataset in found_alf_datasets.items():
        # only datasets of the 'clusters' object are per-unit; testing the part before
        # the first '.' keeps the per-spike 'spikes.clusters' array out
        if 'clusters' not in alf_dataset_name.split('.')[0]:
            continue
        if 'clusters.metrics' in alf_dataset_name:
            # one property per column of the metrics table
            for property_name, property_values in alf_dataset.items():
                properties[property_name] = property_values.tolist()
        else:
            property_name = alf_dataset_name.split('.')[1]
            properties[property_name] = alf_dataset
        if total_units == 0:
            total_units = alf_dataset.shape[0]

    if 'clusters.metrics' in found_alf_datasets and \
            found_alf_datasets['clusters.metrics'].get('cluster_id') is not None:
        unit_ids = found_alf_datasets['clusters.metrics'].get(
            'cluster_id').tolist()
    else:
        unit_ids = list(range(total_units))

    BaseSorting.__init__(self,
                         unit_ids=unit_ids,
                         sampling_frequency=sampling_frequency)
    sorting_segment = ALFSortingSegment(spike_clusters, spike_times,
                                        sampling_frequency)
    self.add_sorting_segment(sorting_segment)

    # add properties
    for property_name, values in properties.items():
        self.set_property(property_name, values)

    self._kwargs = {
        'folder_path': str(Path(folder_path).absolute()),
        'sampling_frequency': sampling_frequency
    }
def __init__(self,
             file_path: PathType,
             electrical_series_name: str = None,
             sampling_frequency: float = None,
             samples_for_rate_estimation: int = 100000):
    """
    Build the sorting from the units table of an NWB file.

    Parameters
    ----------
    file_path: PathType
        Path to the NWB file.
    electrical_series_name: str or None
        Passed to `get_electrical_series` to pick the series the sampling
        frequency is taken from; only used when `sampling_frequency` is None.
    sampling_frequency: float or None
        If None, taken from the electrical series `rate`, or estimated from its
        timestamps when no rate is stored.
    samples_for_rate_estimation: int
        Number of leading timestamps used for the median-difference estimate.
    """
    check_nwb_install()
    self._file_path = str(file_path)
    self._electrical_series_name = electrical_series_name

    # NOTE(review): the IO object is not closed here; presumably the segment keeps
    # reading from the open file through self._nwbfile — confirm before changing
    io = NWBHDF5IO(self._file_path, mode='r', load_namespaces=True)
    self._nwbfile = io.read()

    # must exist even when the caller supplies the sampling frequency,
    # because it is forwarded to the segment below
    timestamps = None
    if sampling_frequency is None:
        # defines the electrical series from where the sorting came from
        # important to know the sampling_frequency
        self._es = get_electrical_series(self._nwbfile,
                                         self._electrical_series_name)
        # get rate
        if self._es.rate is not None:
            sampling_frequency = self._es.rate
        elif getattr(self._es, "timestamps", None) is not None:
            timestamps = self._es.timestamps
            # slice the first samples: a single indexed timestamp has no differences
            sampling_frequency = 1 / np.median(
                np.diff(timestamps[:samples_for_rate_estimation]))

    assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                           "'sampling_frequency' argument"

    # get all units ids
    units_ids = list(self._nwbfile.units.id[:])

    # store units properties and spike features to dictionaries
    properties = dict()
    for column in list(self._nwbfile.units.colnames):
        if column == 'spike_times':
            continue
        # if it is unit_property
        property_values = self._nwbfile.units[column][:]

        # only load columns with same shape for all units
        # (builtin all: np.all on a generator is always truthy; np.shape also
        # accepts scalars and strings, which have no .shape attribute)
        if all(
                np.shape(p) == np.shape(property_values[0])
                for p in property_values):
            properties[column] = property_values
        else:
            print(
                f"Skipping {column} because of unequal shapes across units"
            )

    BaseSorting.__init__(self,
                         sampling_frequency=sampling_frequency,
                         unit_ids=units_ids)
    sorting_segment = NwbSortingSegment(
        nwbfile=self._nwbfile,
        sampling_frequency=sampling_frequency,
        timestamps=timestamps)
    self.add_sorting_segment(sorting_segment)

    for prop_name, values in properties.items():
        self.set_property(prop_name, np.array(values))

    self._kwargs = {
        'file_path': str(Path(file_path).absolute()),
        'electrical_series_name': self._electrical_series_name,
        'sampling_frequency': sampling_frequency,
        'samples_for_rate_estimation': samples_for_rate_estimation
    }
def __init__(self,
             file_path,
             electrical_series_name: str = None,
             sampling_frequency: float = None):
    """
    Parameters
    ----------
    file_path: path to NWB file
    electrical_series_name: str with pynwb.ecephys.ElectricalSeries object name;
        looked up in the file's acquisition when sampling_frequency is None
    sampling_frequency: float; read from the electrical series when None

    Raises
    ------
    Exception
        If sampling_frequency and electrical_series_name are both None and the
        file holds zero or more than one acquisition.
    """
    assert self.installed, self.installation_mesg
    self._file_path = str(file_path)

    # created once, before the column loop, so it exists even without property columns
    properties = dict()
    with NWBHDF5IO(self._file_path, 'r') as io:
        nwbfile = io.read()
        if sampling_frequency is None:
            # defines the electrical series from where the sorting came from
            # important to know the sampling_frequency
            if electrical_series_name is None:
                if len(nwbfile.acquisition) > 1:
                    raise Exception(
                        'More than one acquisition found. You must specify electrical_series_name.'
                    )
                if len(nwbfile.acquisition) == 0:
                    raise Exception(
                        "No acquisitions found in the .nwb file from which to read sampling frequency. "
                        "Please, specify 'sampling_frequency' parameter.")
                es = list(nwbfile.acquisition.values())[0]
            else:
                # resolve the name to the series object; the bare string has no rate
                es = nwbfile.acquisition[electrical_series_name]

            # get rate
            if es.rate is not None:
                sampling_frequency = es.rate
            else:
                sampling_frequency = 1 / (es.timestamps[1] -
                                          es.timestamps[0])

        assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                               "'sampling_frequency' argument"

        # get all units ids
        units_ids = list(nwbfile.units.id[:])

        # store units properties to a dictionary
        indexed_names = {column.name for column in nwbfile.units.columns}
        for item in nwbfile.units.colnames:
            if item == 'spike_times':
                continue
            # test if item is a unit_property or a spike_feature
            if item + '_index' in indexed_names:
                # if it has index, it is a spike_feature: not loaded
                continue
            # if it is unit_property: one value per unit, in table order
            values = []
            for ind in range(len(units_ids)):
                entry = nwbfile.units[item][ind]
                if isinstance(entry, pd.DataFrame):
                    # table references come back as a DataFrame; keep the first row index
                    values.append(entry.index[0])
                else:
                    values.append(entry)
            if values:
                # let numpy infer the dtype so strings are not truncated
                properties[item] = np.array(values)

    BaseSorting.__init__(self,
                         sampling_frequency=sampling_frequency,
                         unit_ids=units_ids)
    sorting_segment = NwbSortingSegment(
        path=self._file_path, sampling_frequency=sampling_frequency)
    self.add_sorting_segment(sorting_segment)

    for prop_name, values in properties.items():
        self.set_property(prop_name, values)

    self._kwargs = {
        'file_path': str(Path(file_path).absolute()),
        'electrical_series_name': electrical_series_name,
        'sampling_frequency': sampling_frequency
    }