Example #1
    def __init__(self, folder_path):
        RecordingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        self.params = read_python(str(phy_folder / 'params.py'))
        datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']

        if (phy_folder / 'channel_map_si.npy').is_file():
            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy')))
            assert len(channel_map) == self.params['n_channels_dat']
        elif (phy_folder / 'channel_map.npy').is_file():
            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy')))
            assert len(channel_map) == self.params['n_channels_dat']
        else:
            channel_map = list(range(self.params['n_channels_dat']))

        BinDatRecordingExtractor.__init__(self, datfile[0], sampling_frequency=float(self.params['sample_rate']),
                                          dtype=self.params['dtype'], numchan=self.params['n_channels_dat'],
                                          recording_channels=list(channel_map))

        if (phy_folder / 'channel_groups.npy').is_file():
            channel_groups = np.load(phy_folder / 'channel_groups.npy')
            assert len(channel_groups) == self.get_num_channels()
            self.set_channel_groups(channel_groups)

        if (phy_folder / 'channel_positions.npy').is_file():
            channel_locations = np.load(phy_folder / 'channel_positions.npy')
            assert len(channel_locations) == self.get_num_channels()
            self.set_channel_locations(channel_locations)

        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
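This matches the __init__ of spikeextractors' PhyRecordingExtractor; a minimal usage sketch (class name and folder path are assumptions, not shown in the snippet):

# a sketch, assuming the class above is spikeextractors' PhyRecordingExtractor
# and that '/data/phy_output' (hypothetical path) contains params.py plus a
# single .dat or .bin raw file
from spikeextractors import PhyRecordingExtractor

recording = PhyRecordingExtractor(folder_path='/data/phy_output')
print(recording.get_num_channels())        # n_channels_dat from params.py
print(recording.get_sampling_frequency())  # sample_rate from params.py
traces = recording.get_traces(start_frame=0, end_frame=1000)  # first 1000 samples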
Example #2
    def __init__(self, dir_path):
        RecordingExtractor.__init__(self)
        phy_folder = Path(dir_path)

        self.params = read_python(str(phy_folder / 'params.py'))
        datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']

        if (phy_folder / 'channel_map_si.npy').is_file():
            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy')))
            assert len(channel_map) == self.params['n_channels_dat']
        elif (phy_folder / 'channel_map.npy').is_file():
            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy')))
            assert len(channel_map) == self.params['n_channels_dat']
        else:
            channel_map = list(range(self.params['n_channels_dat']))

        BinDatRecordingExtractor.__init__(self, datfile[0], samplerate=float(self.params['sample_rate']),
                                          dtype=self.params['dtype'], numchan=self.params['n_channels_dat'],
                                          recording_channels=list(channel_map))

        if (phy_folder / 'channel_groups.npy').is_file():
            channel_groups = np.load(phy_folder / 'channel_groups.npy')
            assert len(channel_groups) == self.get_num_channels()
            for (ch, cg) in zip(self.get_channel_ids(), channel_groups):
                self.set_channel_property(ch, 'group', cg)

        if (phy_folder / 'channel_positions.npy').is_file():
            channel_locations = np.load(phy_folder / 'channel_positions.npy')
            assert len(channel_locations) == self.get_num_channels()
            for (ch, loc) in zip(self.get_channel_ids(), channel_locations):
                self.set_channel_property(ch, 'location', loc)
Example #3
    def __init__(self, kwik_file_or_folder):
        assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
        SortingExtractor.__init__(self)
        kwik_file_or_folder = Path(kwik_file_or_folder)
        kwikfile = None
        klustafolder = None
        if kwik_file_or_folder.is_file():
            assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
            kwikfile = Path(kwik_file_or_folder).absolute()
            klustafolder = kwikfile.parent
        elif kwik_file_or_folder.is_dir():
            klustafolder = kwik_file_or_folder
            kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik']
            if len(kwikfiles) == 1:
                kwikfile = kwikfiles[0]
        assert kwikfile is not None, "Could not load '.kwik' file"

        try:
            config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
            config = read_python(str(config_file))
            sample_rate = config['traces']['sample_rate']
            self._sampling_frequency = sample_rate
        except Exception:
            print("Could not load sampling frequency info")

        F = h5py.File(kwikfile, 'r')
        channel_groups = F.get('channel_groups')
        self._spiketrains = []
        self._unit_ids = []
        unique_units = []
        klusta_units = []
        groups = []
        unit = 0
        for cgroup in channel_groups:
            group_id = int(cgroup)
            try:
                cluster_ids = channel_groups[cgroup]['clusters']['main']
            except Exception:
                print('Unable to extract clusters from', kwikfile)
                continue
            # load the cluster assignments and spike times once per channel group
            clusters = np.array(channel_groups[cgroup]['spikes']['clusters']['main'])
            time_samples = np.array(channel_groups[cgroup]['spikes']['time_samples'])
            for cluster_id in cluster_ids:
                idx = np.nonzero(clusters == int(cluster_id))
                st = time_samples[idx]
                self._spiketrains.append(st)
                klusta_units.append(int(cluster_id))
                unique_units.append(unit)
                unit += 1
                groups.append(group_id)
        if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
            self._unit_ids = klusta_units
        else:
            print('Klusta units are not unique! Using unique unit ids')
            self._unit_ids = unique_units
        for i, u in enumerate(self._unit_ids):
            self.set_unit_property(u, 'group', groups[i])
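A usage sketch for the extractor above (the KlustaSortingExtractor named in its assert message); the folder path is hypothetical:

# a sketch, assuming '/data/klusta_output' (hypothetical path) is a folder
# containing exactly one .kwik file, or the path to the .kwik file itself
from spikeextractors import KlustaSortingExtractor

sorting = KlustaSortingExtractor('/data/klusta_output')
for unit_id in sorting.get_unit_ids():
    st = sorting.get_unit_spike_train(unit_id)  # spike times in samples
    print(unit_id, sorting.get_unit_property(unit_id, 'group'), st.size)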
Example #4
    def __init__(self, klustafolder):
        assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
        klustafolder = Path(klustafolder).absolute()
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
        assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"
        config = read_python(str(config_file))
        sample_rate = config['traces']['sample_rate']
        n_channels = config['traces']['n_channels']
        dtype = config['traces']['dtype']

        BinDatRecordingExtractor.__init__(self, datfile=dat_file, samplerate=sample_rate, numchan=n_channels,
                                          dtype=dtype)
Example #5
def load_probe_file_inplace(recording, probe_file):
    '''
    This is a locally modified version of spikeextractors.extraction_tools.load_probe_file.
    Unlike the original, it loads the probe information "in place" and does NOT
    return a SubRecordingExtractor.

    This is useful to avoid copying local raw data for KS, KS2, TDC, KLUSTA.

    This is a simplified version where there is only one group and **all**
    channels of the file belong to that group.

    Works only for PRB files.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor to attach the channel information to
    probe_file: str
        Path to the probe file (.prb)

    Returns
    -------
    Nothing; the RecordingExtractor is modified in place.
    '''
    probe_file = Path(probe_file)

    assert probe_file.suffix == '.prb'

    probe_dict = read_python(probe_file)

    assert 'channel_groups' in probe_dict.keys()

    assert len(probe_dict['channel_groups']) == 1, \
        'load_probe_file_inplace supports only one group'

    cgroup_id = list(probe_dict['channel_groups'].keys())[0]
    cgroup = probe_dict['channel_groups'][cgroup_id]

    channel_ids = cgroup['channels']
    assert len(channel_ids) == len(recording.get_channel_ids())
    # TODO assert equal array sorted

    for chan_id in channel_ids:
        recording.set_channel_property(chan_id, 'group', int(cgroup_id))
        recording.set_channel_property(chan_id, 'location',
                                       cgroup['geometry'][chan_id])
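For reference, a one-group .prb file that passes the asserts above is itself a Python file defining a channel_groups dict; a sketch with made-up tetrode geometry:

# contents of a hypothetical one-group probe file, e.g. tetrode.prb;
# read_python() executes this file and returns its top-level names
channel_groups = {
    0: {
        'channels': [0, 1, 2, 3],
        'geometry': {0: (0., 0.), 1: (0., 25.), 2: (25., 0.), 3: (25., 25.)},
    }
}

# applying it tags every channel with a 'group' and a 'location' property:
# load_probe_file_inplace(recording, 'tetrode.prb')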
Example #6
    def __init__(self, folder_path):
        assert HAVE_KLSX, self.installation_mesg
        klustafolder = Path(folder_path).absolute()
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
        assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
        n_channels = config['traces']['n_channels']
        dtype = config['traces']['dtype']

        BinDatRecordingExtractor.__init__(self, file_path=dat_file, sampling_frequency=sampling_frequency, numchan=n_channels,
                                          dtype=dtype)

        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
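A usage sketch, assuming the class above is spikeextractors' KlustaRecordingExtractor (the name is not shown in the snippet):

# a sketch; '/data/klusta_output' is a hypothetical folder holding the
# .prm config and the .dat raw file the constructor looks for
from spikeextractors import KlustaRecordingExtractor

recording = KlustaRecordingExtractor(folder_path='/data/klusta_output')
print(recording.get_sampling_frequency())  # sample_rate from the .prm config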
Example #7
    def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
        assert HAVE_KLSX, "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"
        SortingExtractor.__init__(self)
        kwik_file_or_folder = Path(file_or_folder_path)
        kwikfile = None
        klustafolder = None
        if kwik_file_or_folder.is_file():
            assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
            kwikfile = Path(kwik_file_or_folder).absolute()
            klustafolder = kwikfile.parent
        elif kwik_file_or_folder.is_dir():
            klustafolder = kwik_file_or_folder
            kwikfiles = [
                f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik'
            ]
            if len(kwikfiles) == 1:
                kwikfile = kwikfiles[0]
        assert kwikfile is not None, "Could not load '.kwik' file"

        try:
            config_file = [
                f for f in klustafolder.iterdir() if f.suffix == '.prm'
            ][0]
            config = read_python(str(config_file))
            sampling_frequency = config['traces']['sample_rate']
            self._sampling_frequency = sampling_frequency
        except Exception:
            print("Could not load sampling frequency info")

        kf_reader = h5py.File(kwikfile, 'r')
        self._spiketrains = []
        self._unit_ids = []
        unique_units = []
        klusta_units = []
        cluster_groups_name = []
        groups = []
        unit = 0

        cs_to_exclude = []
        valid_group_names = [name.lower() for name in self.default_cluster_groups.values()]
        if exclude_cluster_groups is not None:
            assert isinstance(exclude_cluster_groups,
                              list), 'exclude_cluster_groups should be a list'
            for ec in exclude_cluster_groups:
                assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'
                cs_to_exclude.append(ec.lower())

        for channel_group in kf_reader.get('/channel_groups'):
            # the [()] indexing reads the whole h5py dataset into memory
            chan_cluster_id_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
            chan_cluster_times_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/time_samples')[()]
            # if clusters were merged in the gui, the original ids are still
            # in the kwik tree, but not in this array
            chan_cluster_ids = np.unique(chan_cluster_id_arr)

            for cluster_id in chan_cluster_ids:
                cluster_frame_idx = np.nonzero(chan_cluster_id_arr == cluster_id)
                st = chan_cluster_times_arr[cluster_frame_idx]
                assert st.shape[0] > 0, 'no spikes in cluster'
                cluster_group = kf_reader.get(
                    f'/channel_groups/{channel_group}/clusters/main/{cluster_id}'
                ).attrs['cluster_group']

                assert cluster_group in self.default_cluster_groups, \
                    f'cluster_group not in default_cluster_groups: {cluster_group}'
                cluster_group_name = self.default_cluster_groups[cluster_group]

                if cluster_group_name.lower() in cs_to_exclude:
                    continue

                self._spiketrains.append(st)
                klusta_units.append(int(cluster_id))
                unique_units.append(unit)
                unit += 1
                groups.append(int(channel_group))
                cluster_groups_name.append(cluster_group_name)

        if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
            self._unit_ids = klusta_units
        else:
            print('Klusta units are not unique! Using unique unit ids')
            self._unit_ids = unique_units
        for i, u in enumerate(self._unit_ids):
            self.set_unit_property(u, 'group', groups[i])
            self.set_unit_property(u, 'quality',
                                   cluster_groups_name[i].lower())

        self._kwargs = {
            'file_or_folder_path': str(Path(file_or_folder_path).absolute())
        }
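A sketch of the exclude_cluster_groups filter, assuming the usual Klusta group names ('noise', 'mua', 'good', 'unsorted') populate default_cluster_groups:

# a sketch; the path and the group names are assumptions
from spikeextractors import KlustaSortingExtractor

sorting = KlustaSortingExtractor('/data/klusta_output',
                                 exclude_cluster_groups=['noise', 'mua'])
for unit_id in sorting.get_unit_ids():
    # each kept unit carries its channel group and its quality label
    print(unit_id,
          sorting.get_unit_property(unit_id, 'group'),
          sorting.get_unit_property(unit_id, 'quality'))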
Example #8
    def __init__(self,
                 folder_path,
                 exclude_cluster_groups=None,
                 load_waveforms=False,
                 verbose=False):
        SortingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        if (phy_folder / 'amplitudes.npy').is_file():
            amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
        else:
            amplitudes = np.ones(len(spike_times))

        if (phy_folder / 'pc_features.npy').is_file():
            pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
        else:
            pc_features = None

        clust_id = np.unique(spike_clusters)
        self._unit_ids = list(clust_id)
        spike_times = spike_times.astype(int)
        self.params = read_python(str(phy_folder / 'params.py'))
        self._sampling_frequency = self.params['sample_rate']

        # set unit quality properties
        csv_tsv_files = [
            x for x in phy_folder.iterdir()
            if x.suffix == '.csv' or x.suffix == '.tsv'
        ]
        for f in csv_tsv_files:
            if f.suffix == '.csv':
                with f.open() as csv_file:
                    csv_reader = csv.reader(csv_file, delimiter=',')
                    line_count = 0
                    for row in csv_reader:
                        if line_count == 0:
                            # header row: the second tab-separated token is the property name
                            tokens = row[0].split("\t")
                            property_name = tokens[1]
                        else:
                            tokens = row[0].split("\t")
                            if int(tokens[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(
                                        int(tokens[0]), 'quality', tokens[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(
                                        int(tokens[0]), 'group', tokens[1])
                                elif isinstance(tokens[1], (int, float, str)):
                                    self.set_unit_property(
                                        int(tokens[0]), property_name, tokens[1])
                        # count every row so that only the first one is treated
                        # as the header
                        line_count += 1
            elif f.suffix == '.tsv':
                with f.open() as csv_file:
                    csv_reader = csv.reader(csv_file, delimiter='\t')
                    line_count = 0
                    for row in csv_reader:
                        if line_count == 0:
                            property_name = row[1]
                        elif len(row) == 2:
                            if int(row[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(
                                        int(row[0]), 'quality', row[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(
                                        int(row[0]), 'group', row[1])
                                elif isinstance(row[1], (int, float, str)):
                                    self.set_unit_property(
                                        int(row[0]), property_name, row[1])
                        line_count += 1

        for unit in self.get_unit_ids():
            if 'quality' not in self.get_unit_property_names(unit):
                self.set_unit_property(unit, 'quality', 'unsorted')

        if exclude_cluster_groups is not None:
            if len(exclude_cluster_groups) > 0:
                included_units = []
                for u in self.get_unit_ids():
                    if self.get_unit_property(
                            u, 'quality') not in exclude_cluster_groups:
                        included_units.append(u)
            else:
                included_units = self._unit_ids
        else:
            included_units = self._unit_ids

        original_units = self._unit_ids
        self._unit_ids = included_units
        # set features
        self._spiketrains = []
        for clust in self._unit_ids:
            idx = np.where(spike_clusters == clust)[0]
            self._spiketrains.append(spike_times[idx])
            self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])
            if pc_features is not None:
                self.set_unit_spike_features(clust, 'pc_features',
                                             pc_features[idx])

        if load_waveforms:
            datfile = [
                x for x in phy_folder.iterdir()
                if x.suffix == '.dat' or x.suffix == '.bin'
            ]

            recording = BinDatRecordingExtractor(
                datfile[0],
                sampling_frequency=float(self.params['sample_rate']),
                dtype=self.params['dtype'],
                numchan=self.params['n_channels_dat'])
            # if channel groups are present, compute waveforms by group
            if (phy_folder / 'channel_groups.npy').is_file():
                channel_groups = np.load(phy_folder / 'channel_groups.npy')
                assert len(channel_groups) == recording.get_num_channels()
                recording.set_channel_groups(channel_groups)
                for u_i, u in enumerate(self.get_unit_ids()):
                    if verbose:
                        print('Computing waveform by group for unit', u)
                    frames_before = int(0.5 / 1000. *
                                        recording.get_sampling_frequency())
                    frames_after = int(2 / 1000. *
                                       recording.get_sampling_frequency())
                    spiketrain = self.get_unit_spike_train(u)
                    if 'group' in self.get_unit_property_names(u):
                        group_idx = np.where(channel_groups == int(
                            self.get_unit_property(u, 'group')))[0]
                        wf = recording.get_snippets(
                            reference_frames=spiketrain,
                            snippet_len=[frames_before, frames_after],
                            channel_ids=group_idx)
                    else:
                        wf = recording.get_snippets(
                            reference_frames=spiketrain,
                            snippet_len=[frames_before, frames_after])
                        # average across spikes, then pick the channel with the
                        # largest negative peak
                        mean_wf = np.mean(wf, axis=0)
                        max_chan = np.unravel_index(np.argmin(mean_wf),
                                                    mean_wf.shape)[0]
                        group = recording.get_channel_groups(int(max_chan))
                        self.set_unit_property(u, 'group', group)
                        group_idx = np.where(channel_groups == group)[0]
                        wf = wf[:, group_idx]
                    self.set_unit_spike_features(u, 'waveforms', wf)
            else:
                for u_i, u in enumerate(self.get_unit_ids()):
                    if verbose:
                        print('Computing full waveform for unit', u)
                    # use the same 0.5 ms / 2 ms window as the by-group branch
                    frames_before = int(0.5 / 1000. *
                                        recording.get_sampling_frequency())
                    frames_after = int(2 / 1000. *
                                       recording.get_sampling_frequency())
                    spiketrain = self.get_unit_spike_train(u)
                    wf = recording.get_snippets(
                        reference_frames=spiketrain,
                        snippet_len=[frames_before, frames_after])
                    self.set_unit_spike_features(u, 'waveforms', wf)
        self._kwargs = {
            'folder_path': str(Path(folder_path).absolute()),
            'exclude_cluster_groups': exclude_cluster_groups,
            'load_waveforms': load_waveforms,
            'verbose': verbose
        }
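A usage sketch, assuming the class above is spikeextractors' PhySortingExtractor; with load_waveforms=True each unit also gets a 'waveforms' spike feature:

# a sketch, assuming the class above is spikeextractors' PhySortingExtractor
# and a hypothetical Phy folder that includes the raw .dat/.bin file
from spikeextractors import PhySortingExtractor

sorting = PhySortingExtractor(folder_path='/data/phy_output',
                              exclude_cluster_groups=['noise'],
                              load_waveforms=True)
unit0 = sorting.get_unit_ids()[0]
wf = sorting.get_unit_spike_features(unit0, 'waveforms')    # one snippet per spike
amps = sorting.get_unit_spike_features(unit0, 'amplitudes')
print(len(wf), amps.shape)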
Example #9
    def __init__(self,
                 folder_path: PathType,
                 exclude_cluster_groups: Optional[list] = None):
        SortingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        if (phy_folder / 'amplitudes.npy').is_file():
            amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))
        else:
            amplitudes = np.ones(len(spike_times))

        if (phy_folder / 'pc_features.npy').is_file():
            pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))
        else:
            pc_features = None

        clust_id = np.unique(spike_clusters)
        self._unit_ids = list(clust_id)
        spike_times = spike_times.astype(int)
        self.params = read_python(str(phy_folder / 'params.py'))
        self._sampling_frequency = self.params['sample_rate']

        # set unit quality properties
        csv_tsv_files = [
            x for x in phy_folder.iterdir()
            if x.suffix == '.csv' or x.suffix == '.tsv'
        ]
        for f in csv_tsv_files:
            if f.suffix == '.csv':
                with f.open() as csv_file:
                    csv_reader = csv.reader(csv_file, delimiter=',')
                    line_count = 0
                    for row in csv_reader:
                        if line_count == 0:
                            # header row: the second tab-separated token is the property name
                            tokens = row[0].split("\t")
                            property_name = tokens[1]
                        else:
                            tokens = row[0].split("\t")
                            if int(tokens[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(
                                        int(tokens[0]), 'quality', tokens[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(
                                        int(tokens[0]), 'group', tokens[1])
                                elif isinstance(tokens[1], (int, float, str)):
                                    self.set_unit_property(
                                        int(tokens[0]), property_name, tokens[1])
                        # count every row so that only the first one is treated
                        # as the header
                        line_count += 1
            elif f.suffix == '.tsv':
                with f.open() as csv_file:
                    csv_reader = csv.reader(csv_file, delimiter='\t')
                    line_count = 0
                    for row in csv_reader:
                        if line_count == 0:
                            property_name = row[1]
                        elif len(row) == 2:
                            if int(row[0]) in self.get_unit_ids():
                                if 'cluster_group' in str(f):
                                    self.set_unit_property(
                                        int(row[0]), 'quality', row[1])
                                elif property_name == 'chan_grp':
                                    self.set_unit_property(
                                        int(row[0]), 'group', row[1])
                                elif isinstance(row[1], (int, float, str)):
                                    self.set_unit_property(
                                        int(row[0]), property_name, row[1])
                        line_count += 1

        for unit in self.get_unit_ids():
            if 'quality' not in self.get_unit_property_names(unit):
                self.set_unit_property(unit, 'quality', 'unsorted')

        if exclude_cluster_groups is not None:
            if len(exclude_cluster_groups) > 0:
                included_units = []
                for u in self.get_unit_ids():
                    if self.get_unit_property(
                            u, 'quality') not in exclude_cluster_groups:
                        included_units.append(u)
            else:
                included_units = self._unit_ids
        else:
            included_units = self._unit_ids

        original_units = self._unit_ids
        self._unit_ids = included_units
        # set features
        self._spiketrains = []
        for clust in self._unit_ids:
            idx = np.where(spike_clusters == clust)[0]
            self._spiketrains.append(spike_times[idx])
            self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])
            if pc_features is not None:
                self.set_unit_spike_features(clust, 'pc_features',
                                             pc_features[idx])

        self._kwargs = {
            'folder_path': str(Path(folder_path).absolute()),
            'exclude_cluster_groups': exclude_cluster_groups
        }