def __init__(self, folder_path):
    RecordingExtractor.__init__(self)
    phy_folder = Path(folder_path)

    self.params = read_python(str(phy_folder / 'params.py'))
    datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']

    if (phy_folder / 'channel_map_si.npy').is_file():
        channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy')))
        assert len(channel_map) == self.params['n_channels_dat']
    elif (phy_folder / 'channel_map.npy').is_file():
        channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy')))
        assert len(channel_map) == self.params['n_channels_dat']
    else:
        channel_map = list(range(self.params['n_channels_dat']))

    BinDatRecordingExtractor.__init__(self, datfile[0], sampling_frequency=float(self.params['sample_rate']),
                                      dtype=self.params['dtype'], numchan=self.params['n_channels_dat'],
                                      recording_channels=list(channel_map))

    if (phy_folder / 'channel_groups.npy').is_file():
        channel_groups = np.load(phy_folder / 'channel_groups.npy')
        assert len(channel_groups) == self.get_num_channels()
        self.set_channel_groups(channel_groups)

    if (phy_folder / 'channel_positions.npy').is_file():
        channel_locations = np.load(phy_folder / 'channel_positions.npy')
        assert len(channel_locations) == self.get_num_channels()
        self.set_channel_locations(channel_locations)

    self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
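# Usage sketch for the extractor above -- a minimal example, assuming the
# surrounding class is the spikeextractors-style PhyRecordingExtractor (the
# class name is inferred, not shown in this snippet) and that 'phy_output'
# is a hypothetical folder containing params.py plus a .dat/.bin file.
recording = PhyRecordingExtractor(folder_path='phy_output')
print(recording.get_num_channels(), recording.get_sampling_frequency())
traces = recording.get_traces(start_frame=0, end_frame=1000)  # first 1000 frames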
def __init__(self, dir_path):
    RecordingExtractor.__init__(self)
    phy_folder = Path(dir_path)

    self.params = read_python(str(phy_folder / 'params.py'))
    datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']

    if (phy_folder / 'channel_map_si.npy').is_file():
        channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy')))
        assert len(channel_map) == self.params['n_channels_dat']
    elif (phy_folder / 'channel_map.npy').is_file():
        channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy')))
        assert len(channel_map) == self.params['n_channels_dat']
    else:
        channel_map = list(range(self.params['n_channels_dat']))

    BinDatRecordingExtractor.__init__(self, datfile[0], samplerate=float(self.params['sample_rate']),
                                      dtype=self.params['dtype'], numchan=self.params['n_channels_dat'],
                                      recording_channels=list(channel_map))

    if (phy_folder / 'channel_groups.npy').is_file():
        channel_groups = np.load(phy_folder / 'channel_groups.npy')
        assert len(channel_groups) == self.get_num_channels()
        for (ch, cg) in zip(self.get_channel_ids(), channel_groups):
            self.set_channel_property(ch, 'group', cg)

    if (phy_folder / 'channel_positions.npy').is_file():
        channel_locations = np.load(phy_folder / 'channel_positions.npy')
        assert len(channel_locations) == self.get_num_channels()
        for (ch, loc) in zip(self.get_channel_ids(), channel_locations):
            self.set_channel_property(ch, 'location', loc)
def __init__(self, kwik_file_or_folder):
    assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
    SortingExtractor.__init__(self)
    kwik_file_or_folder = Path(kwik_file_or_folder)
    kwikfile = None
    klustafolder = None

    if kwik_file_or_folder.is_file():
        assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
        kwikfile = Path(kwik_file_or_folder).absolute()
        klustafolder = kwikfile.parent
    elif kwik_file_or_folder.is_dir():
        klustafolder = kwik_file_or_folder
        kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik']
        if len(kwikfiles) == 1:
            kwikfile = kwikfiles[0]
    assert kwikfile is not None, "Could not load '.kwik' file"

    try:
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        config = read_python(str(config_file))
        sample_rate = config['traces']['sample_rate']
        self._sampling_frequency = sample_rate
    except Exception as e:
        print("Could not load sampling frequency info")

    F = h5py.File(kwikfile)
    channel_groups = F.get('channel_groups')
    self._spiketrains = []
    self._unit_ids = []
    unique_units = []
    klusta_units = []
    groups = []
    unit = 0

    for cgroup in channel_groups:
        group_id = int(cgroup)
        try:
            cluster_ids = channel_groups[cgroup]['clusters']['main']
        except Exception as e:
            print('Unable to extract clusters from', kwikfile)
            continue
        for cluster_id in channel_groups[cgroup]['clusters']['main']:
            clusters = np.array(channel_groups[cgroup]['spikes']['clusters']['main'])
            idx = np.nonzero(clusters == int(cluster_id))
            st = np.array(channel_groups[cgroup]['spikes']['time_samples'])[idx]
            self._spiketrains.append(st)
            klusta_units.append(int(cluster_id))
            unique_units.append(unit)
            unit += 1
            groups.append(group_id)

    if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
        self._unit_ids = klusta_units
    else:
        print('Klusta units are not unique! Using unique unit ids')
        self._unit_ids = unique_units

    for i, u in enumerate(self._unit_ids):
        self.set_unit_property(u, 'group', groups[i])
def __init__(self, klustafolder):
    assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
    klustafolder = Path(klustafolder).absolute()
    config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
    dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
    assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"

    config = read_python(str(config_file))
    sample_rate = config['traces']['sample_rate']
    n_channels = config['traces']['n_channels']
    dtype = config['traces']['dtype']

    BinDatRecordingExtractor.__init__(self, datfile=dat_file, samplerate=sample_rate,
                                      numchan=n_channels, dtype=dtype)
def load_probe_file_inplace(recording, probe_file):
    '''
    Locally modified version of spikeextractors.extraction_tools.load_probe_file.
    It loads the probe information "in place" and does NOT return a
    SubRecordingExtractor. This is useful to avoid copying local raw data for
    KS, KS2, TDC, KLUSTA.

    This is a simplified version where there is only one group and **all**
    channels of the file belong to that group.

    Works only for PRB files.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to attach the channel information to
    probe_file: str
        Path to the probe file (.prb)

    Returns
    -------
    Nothing, in-place modification of the RecordingExtractor
    '''
    probe_file = Path(probe_file)
    assert probe_file.suffix == '.prb'
    probe_dict = read_python(probe_file)
    assert 'channel_groups' in probe_dict.keys()
    assert len(probe_dict['channel_groups']) == 1, 'load_probe_file_inplace only for one group'

    cgroup_id = list(probe_dict['channel_groups'].keys())[0]
    cgroup = probe_dict['channel_groups'][cgroup_id]
    channel_ids = cgroup['channels']
    assert len(channel_ids) == len(recording.get_channel_ids())
    # TODO assert equal array sorted

    for chan_id in channel_ids:
        recording.set_channel_property(chan_id, 'group', int(cgroup_id))
        recording.set_channel_property(chan_id, 'location', cgroup['geometry'][chan_id])
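# Usage sketch for load_probe_file_inplace -- assumptions: 'tetrode.prb' is a
# hypothetical single-group probe file, e.g.
#
#     channel_groups = {
#         0: {'channels': [0, 1, 2, 3],
#             'geometry': {0: (0, 0), 1: (0, 50), 2: (50, 0), 3: (50, 50)}}
#     }
#
# and 'rec' is any RecordingExtractor with the same four channels; the
# KlustaRecordingExtractor class name and 'klusta_output' folder are inferred.
rec = KlustaRecordingExtractor(folder_path='klusta_output')  # hypothetical folder
load_probe_file_inplace(rec, 'tetrode.prb')  # modifies rec in place, returns nothing
print(rec.get_channel_property(0, 'location'))  # -> (0, 0) for the file above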
def __init__(self, folder_path):
    assert HAVE_KLSX, self.installation_mesg
    klustafolder = Path(folder_path).absolute()
    config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
    dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
    assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"

    config = read_python(str(config_file))
    sampling_frequency = config['traces']['sample_rate']
    n_channels = config['traces']['n_channels']
    dtype = config['traces']['dtype']

    BinDatRecordingExtractor.__init__(self, file_path=dat_file, sampling_frequency=sampling_frequency,
                                      numchan=n_channels, dtype=dtype)

    self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
    assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
    SortingExtractor.__init__(self)
    kwik_file_or_folder = Path(file_or_folder_path)
    kwikfile = None
    klustafolder = None

    if kwik_file_or_folder.is_file():
        assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
        kwikfile = Path(kwik_file_or_folder).absolute()
        klustafolder = kwikfile.parent
    elif kwik_file_or_folder.is_dir():
        klustafolder = kwik_file_or_folder
        kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik']
        if len(kwikfiles) == 1:
            kwikfile = kwikfiles[0]
    assert kwikfile is not None, "Could not load '.kwik' file"

    try:
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
        self._sampling_frequency = sampling_frequency
    except Exception as e:
        print("Could not load sampling frequency info")

    kf_reader = h5py.File(kwikfile, 'r')
    self._spiketrains = []
    self._unit_ids = []
    unique_units = []
    klusta_units = []
    cluster_groups_name = []
    groups = []
    unit = 0

    cs_to_exclude = []
    valid_group_names = [i[1].lower() for i in self.default_cluster_groups.items()]
    if exclude_cluster_groups is not None:
        assert isinstance(exclude_cluster_groups, list), 'exclude_cluster_groups should be a list'
        for ec in exclude_cluster_groups:
            assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'
            cs_to_exclude.append(ec.lower())

    for channel_group in kf_reader.get('/channel_groups'):
        # the [()] is an h5py idiom that reads the full dataset into memory
        chan_cluster_id_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
        chan_cluster_times_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/time_samples')[()]
        # if clusters were merged in the GUI, the original ids are still in
        # the kwik tree, but not in this array
        chan_cluster_ids = np.unique(chan_cluster_id_arr)

        for cluster_id in chan_cluster_ids:
            cluster_frame_idx = np.nonzero(chan_cluster_id_arr == cluster_id)
            st = chan_cluster_times_arr[cluster_frame_idx]
            assert st.shape[0] > 0, 'no spikes in cluster'
            cluster_group = kf_reader.get(
                f'/channel_groups/{channel_group}/clusters/main/{cluster_id}').attrs['cluster_group']
            assert cluster_group in self.default_cluster_groups.keys(), \
                f'cluster_group not in default_cluster_groups: {cluster_group}'
            cluster_group_name = self.default_cluster_groups[cluster_group]
            if cluster_group_name.lower() in cs_to_exclude:
                continue

            self._spiketrains.append(st)
            klusta_units.append(int(cluster_id))
            unique_units.append(unit)
            unit += 1
            groups.append(int(channel_group))
            cluster_groups_name.append(cluster_group_name)

    if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
        self._unit_ids = klusta_units
    else:
        print('Klusta units are not unique! Using unique unit ids')
        self._unit_ids = unique_units

    for i, u in enumerate(self._unit_ids):
        self.set_unit_property(u, 'group', groups[i])
        self.set_unit_property(u, 'quality', cluster_groups_name[i].lower())

    self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute())}
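# Usage sketch -- the class name KlustaSortingExtractor comes from the assert
# message above; 'klusta_output' is a hypothetical folder holding one .kwik
# file, and the exclude names must be values of self.default_cluster_groups
# (typically 'noise', 'mua', 'good', 'unsorted' in kwik output).
sorting = KlustaSortingExtractor('klusta_output', exclude_cluster_groups=['noise', 'mua'])
for u in sorting.get_unit_ids():
    print(u, sorting.get_unit_property(u, 'quality'), len(sorting.get_unit_spike_train(u)))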
def __init__(self, folder_path, exclude_cluster_groups=None, load_waveforms=False, verbose=False):
    SortingExtractor.__init__(self)
    phy_folder = Path(folder_path)

    spike_times = np.load(phy_folder / 'spike_times.npy')
    spike_templates = np.load(phy_folder / 'spike_templates.npy')

    if (phy_folder / 'spike_clusters.npy').is_file():
        spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
    else:
        spike_clusters = spike_templates

    if (phy_folder / 'amplitudes.npy').is_file():
        amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
    else:
        amplitudes = np.ones(len(spike_times))

    if (phy_folder / 'pc_features.npy').is_file():
        pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
    else:
        pc_features = None

    clust_id = np.unique(spike_clusters)
    self._unit_ids = list(clust_id)
    spike_times = spike_times.astype(int)  # astype returns a copy; it must be reassigned
    self.params = read_python(str(phy_folder / 'params.py'))
    self._sampling_frequency = self.params['sample_rate']

    # set unit quality properties
    csv_tsv_files = [x for x in phy_folder.iterdir() if x.suffix == '.csv' or x.suffix == '.tsv']
    for f in csv_tsv_files:
        if f.suffix == '.csv':
            with f.open() as csv_file:
                csv_reader = csv.reader(csv_file, delimiter=',')
                line_count = 0
                for row in csv_reader:
                    if line_count == 0:
                        tokens = row[0].split("\t")
                        property_name = tokens[1]
                    else:
                        tokens = row[0].split("\t")
                        if int(tokens[0]) in self.get_unit_ids():
                            if 'cluster_group' in str(f):
                                self.set_unit_property(int(tokens[0]), 'quality', tokens[1])
                            elif property_name == 'chan_grp':
                                self.set_unit_property(int(tokens[0]), 'group', tokens[1])
                            else:
                                if isinstance(tokens[1], (int, float, str)):
                                    self.set_unit_property(int(tokens[0]), property_name, tokens[1])
                    line_count += 1
        elif f.suffix == '.tsv':
            with f.open() as csv_file:
                csv_reader = csv.reader(csv_file, delimiter='\t')
                line_count = 0
                for row in csv_reader:
                    if line_count == 0:
                        property_name = row[1]
                    else:
                        if len(row) == 2:
                            if int(row[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(int(row[0]), 'quality', row[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(int(row[0]), 'group', row[1])
                                else:
                                    if isinstance(row[1], (int, float, str)) and len(row) == 2:
                                        self.set_unit_property(int(row[0]), property_name, row[1])
                    line_count += 1

    for unit in self.get_unit_ids():
        if 'quality' not in self.get_unit_property_names(unit):
            self.set_unit_property(unit, 'quality', 'unsorted')

    if exclude_cluster_groups is not None:
        if len(exclude_cluster_groups) > 0:
            included_units = []
            for u in self.get_unit_ids():
                if self.get_unit_property(u, 'quality') not in exclude_cluster_groups:
                    included_units.append(u)
        else:
            included_units = self._unit_ids
    else:
        included_units = self._unit_ids

    original_units = self._unit_ids
    self._unit_ids = included_units

    # set features
    self._spiketrains = []
    for clust in self._unit_ids:
        idx = np.where(spike_clusters == clust)[0]
        self._spiketrains.append(spike_times[idx])
        self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])
        if pc_features is not None:
            self.set_unit_spike_features(clust, 'pc_features', pc_features[idx])

    if load_waveforms:
        datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']
        recording = BinDatRecordingExtractor(datfile[0],
                                             sampling_frequency=float(self.params['sample_rate']),
                                             dtype=self.params['dtype'],
                                             numchan=self.params['n_channels_dat'])
        # if channel groups are present, compute waveforms by group
        if (phy_folder / 'channel_groups.npy').is_file():
            channel_groups = np.load(phy_folder / 'channel_groups.npy')
            assert len(channel_groups) == recording.get_num_channels()
            recording.set_channel_groups(channel_groups)
            for u_i, u in enumerate(self.get_unit_ids()):
                if verbose:
                    print('Computing waveform by group for unit', u)
                # snippet window: 0.5 ms before and 2 ms after each spike
                frames_before = int(0.5 / 1000. * recording.get_sampling_frequency())
                frames_after = int(2 / 1000. * recording.get_sampling_frequency())
                spiketrain = self.get_unit_spike_train(u)
                if 'group' in self.get_unit_property_names(u):
                    group_idx = np.where(channel_groups == int(self.get_unit_property(u, 'group')))[0]
                    wf = recording.get_snippets(reference_frames=spiketrain,
                                                snippet_len=[frames_before, frames_after],
                                                channel_ids=group_idx)
                else:
                    wf = recording.get_snippets(reference_frames=spiketrain,
                                                snippet_len=[frames_before, frames_after])
                    # infer the group from the channel with the largest negative peak
                    max_chan = np.unravel_index(np.argmin(np.mean(wf, axis=0)),
                                                np.mean(wf, axis=0).shape)[0]
                    group = recording.get_channel_groups(int(max_chan))
                    self.set_unit_property(u, 'group', group)
                    group_idx = np.where(channel_groups == group)[0]
                    wf = wf[:, group_idx]
                self.set_unit_spike_features(u, 'waveforms', wf)
        else:
            for u_i, u in enumerate(self.get_unit_ids()):
                if verbose:
                    print('Computing full waveform for unit', u)
                # same 0.5 ms / 2 ms window as the by-group branch (the original
                # omitted the /1000 ms conversion here, yielding second-long snippets)
                frames_before = int(0.5 / 1000. * recording.get_sampling_frequency())
                frames_after = int(2 / 1000. * recording.get_sampling_frequency())
                spiketrain = self.get_unit_spike_train(u)
                wf = recording.get_snippets(reference_frames=spiketrain,
                                            snippet_len=[frames_before, frames_after])
                self.set_unit_spike_features(u, 'waveforms', wf)

    self._kwargs = {'folder_path': str(Path(folder_path).absolute()),
                    'exclude_cluster_groups': exclude_cluster_groups,
                    'load_waveforms': load_waveforms, 'verbose': verbose}
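# Usage sketch for the waveform-loading path above -- hedged: the class name
# PhySortingExtractor is inferred, 'phy_output' is a hypothetical folder, and
# load_waveforms=True requires the raw .dat/.bin file referenced by params.py
# to still sit inside the folder.
sorting = PhySortingExtractor('phy_output', exclude_cluster_groups=['noise'],
                              load_waveforms=True, verbose=True)
u0 = sorting.get_unit_ids()[0]
wf = sorting.get_unit_spike_features(u0, 'waveforms')  # (n_spikes, n_channels, n_frames)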
def __init__(self, folder_path: PathType, exclude_cluster_groups: Optional[list] = None):
    SortingExtractor.__init__(self)
    phy_folder = Path(folder_path)

    spike_times = np.load(phy_folder / 'spike_times.npy')
    spike_templates = np.load(phy_folder / 'spike_templates.npy')

    if (phy_folder / 'spike_clusters.npy').is_file():
        spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
    else:
        spike_clusters = spike_templates

    if (phy_folder / 'amplitudes.npy').is_file():
        amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
    else:
        amplitudes = np.ones(len(spike_times))

    if (phy_folder / 'pc_features.npy').is_file():
        pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
    else:
        pc_features = None

    clust_id = np.unique(spike_clusters)
    self._unit_ids = list(clust_id)
    spike_times = spike_times.astype(int)  # astype returns a copy; it must be reassigned
    self.params = read_python(str(phy_folder / 'params.py'))
    self._sampling_frequency = self.params['sample_rate']

    # set unit quality properties
    csv_tsv_files = [x for x in phy_folder.iterdir() if x.suffix == '.csv' or x.suffix == '.tsv']
    for f in csv_tsv_files:
        if f.suffix == '.csv':
            with f.open() as csv_file:
                csv_reader = csv.reader(csv_file, delimiter=',')
                line_count = 0
                for row in csv_reader:
                    if line_count == 0:
                        tokens = row[0].split("\t")
                        property_name = tokens[1]
                    else:
                        tokens = row[0].split("\t")
                        if int(tokens[0]) in self.get_unit_ids():
                            if 'cluster_group' in str(f):
                                self.set_unit_property(int(tokens[0]), 'quality', tokens[1])
                            elif property_name == 'chan_grp':
                                self.set_unit_property(int(tokens[0]), 'group', tokens[1])
                            else:
                                if isinstance(tokens[1], (int, float, str)):
                                    self.set_unit_property(int(tokens[0]), property_name, tokens[1])
                    line_count += 1
        elif f.suffix == '.tsv':
            with f.open() as csv_file:
                csv_reader = csv.reader(csv_file, delimiter='\t')
                line_count = 0
                for row in csv_reader:
                    if line_count == 0:
                        property_name = row[1]
                    else:
                        if len(row) == 2:
                            if int(row[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(int(row[0]), 'quality', row[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(int(row[0]), 'group', row[1])
                                else:
                                    if isinstance(row[1], (int, float, str)) and len(row) == 2:
                                        self.set_unit_property(int(row[0]), property_name, row[1])
                    line_count += 1

    for unit in self.get_unit_ids():
        if 'quality' not in self.get_unit_property_names(unit):
            self.set_unit_property(unit, 'quality', 'unsorted')

    if exclude_cluster_groups is not None:
        if len(exclude_cluster_groups) > 0:
            included_units = []
            for u in self.get_unit_ids():
                if self.get_unit_property(u, 'quality') not in exclude_cluster_groups:
                    included_units.append(u)
        else:
            included_units = self._unit_ids
    else:
        included_units = self._unit_ids

    original_units = self._unit_ids
    self._unit_ids = included_units

    # set features
    self._spiketrains = []
    for clust in self._unit_ids:
        idx = np.where(spike_clusters == clust)[0]
        self._spiketrains.append(spike_times[idx])
        self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])
        if pc_features is not None:
            self.set_unit_spike_features(clust, 'pc_features', pc_features[idx])

    self._kwargs = {'folder_path': str(Path(folder_path).absolute()),
                    'exclude_cluster_groups': exclude_cluster_groups}