Example #1
    def __init__(self,
                 folder_path,
                 exclude_cluster_groups=None,
                 keep_good_only=False):
        """Read a Phy/Kilosort output folder as a sorting.

        Parameters
        ----------
        folder_path : str or Path
            Folder with 'spike_times.npy', 'spike_templates.npy',
            'params.py' and optional cluster_info / tsv / csv files.
        exclude_cluster_groups : str or list of str, optional
            Cluster group name(s) (e.g. 'noise') to exclude.
        keep_good_only : bool, default False
            If True, keep only clusters whose 'KSLabel' is 'good'.
        """
        try:
            import pandas as pd
            HAVE_PD = True
        except ImportError:
            HAVE_PD = False
        assert HAVE_PD, self.installation_mesg

        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        # curated cluster assignments (if present) override the raw templates
        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        clust_id = np.unique(spike_clusters)
        unit_ids = list(clust_id)
        # BUGFIX: astype() returns a new array; the original discarded it,
        # leaving spike_times with its on-disk dtype
        spike_times = spike_times.astype(int)
        params = read_python(str(phy_folder / 'params.py'))
        sampling_frequency = params['sample_rate']

        # try to load cluster info
        cluster_info_files = [
            p for p in phy_folder.iterdir()
            if p.suffix in ['.csv', '.tsv'] and "cluster_info" in p.name
        ]

        if len(cluster_info_files) == 1:
            # load properties from the single cluster_info file
            cluster_info_file = cluster_info_files[0]
            delimiter = "\t" if cluster_info_file.suffix == ".tsv" else ","
            cluster_info = pd.read_csv(cluster_info_file, delimiter=delimiter)
        else:
            # otherwise merge every tsv/csv property file on 'cluster_id'
            all_property_files = [
                p for p in phy_folder.iterdir()
                if p.suffix in ['.csv', '.tsv']
            ]

            cluster_info = None
            for file in all_property_files:
                delimiter = "\t" if file.suffix == ".tsv" else ","
                new_property = pd.read_csv(file, delimiter=delimiter)
                if cluster_info is None:
                    cluster_info = new_property
                else:
                    cluster_info = pd.merge(cluster_info,
                                            new_property,
                                            on='cluster_id')

        # in case no tsv/csv files are found populate cluster info with minimal info
        if cluster_info is None:
            cluster_info = pd.DataFrame({'cluster_id': unit_ids})
            cluster_info['group'] = ['unsorted'] * len(unit_ids)

        if exclude_cluster_groups is not None:
            if isinstance(exclude_cluster_groups, str):
                cluster_info = cluster_info.query(
                    f"group != '{exclude_cluster_groups}'")
            elif isinstance(exclude_cluster_groups, list):
                for exclude_group in exclude_cluster_groups:
                    cluster_info = cluster_info.query(
                        f"group != '{exclude_group}'")

        if keep_good_only and "KSLabel" in cluster_info.columns:
            # BUGFIX: keep the clusters labeled 'good' -- the original used
            # "KSLabel != 'good'", the exact opposite of what the flag promises
            cluster_info = cluster_info.query("KSLabel == 'good'")
        unit_ids = cluster_info["cluster_id"].values

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        del cluster_info["cluster_id"]
        for prop_name in cluster_info.columns:
            if prop_name in ['chan_grp', 'ch_group']:
                self.set_property(key="group", values=cluster_info[prop_name])
            elif prop_name == "group":
                # rename the 'group' property to 'quality'
                self.set_property(key="quality",
                                  values=cluster_info[prop_name])
            else:
                self.set_property(key=prop_name,
                                  values=cluster_info[prop_name])

        self.add_sorting_segment(PhySortingSegment(spike_times,
                                                   spike_clusters))
Example #2
    def __init__(self,
                 folder_path,
                 sampling_frequency=None,
                 user='******',
                 det_sign='both'):
        """Read a Combinato spike-sorting output folder.

        Parameters
        ----------
        folder_path : str or Path
            Combinato output folder; a sibling '<folder>.h5' file may hold
            the sampling rate under its 'sr' dataset.
        sampling_frequency : float, optional
            If None, read from the sibling h5 file when it exists.
        user : str
            User suffix of the 'sort_<sign>_<user>' result folders.
        det_sign : str
            'neg', 'pos' or 'both' -- which detection signs to load.
        """

        folder_path = Path(folder_path)
        assert folder_path.is_dir(), 'Folder {} doesn\'t exist'.format(
            folder_path)
        # recover the sampling rate from the companion '<folder>.h5' file
        if sampling_frequency is None:
            h5_path = str(folder_path) + '.h5'
            if Path(h5_path).exists():
                with h5py.File(h5_path, mode='r') as f:
                    sampling_frequency = f['sr'][0]

        # ~ self.set_sampling_frequency(sampling_frequency)
        # detection file holding the spike times per sign
        det_file = str(folder_path / Path('data_' + folder_path.stem + '.h5'))
        # collect the sort_cat.h5 result files for each requested sign
        sort_cat_files = []
        for sign in ['neg', 'pos']:
            if det_sign in ['both', sign]:
                sort_cat_file = folder_path / Path(
                    'sort_{}_{}/sort_cat.h5'.format(sign, user))
                if sort_cat_file.exists():
                    sort_cat_files.append((sign, str(sort_cat_file)))

        unit_counter = 0
        spiketrains = {}
        metadata = {}
        unsorted = []  # NOTE(review): never used below
        with h5py.File(det_file, mode='r') as fdet:
            for sign, sfile in sort_cat_files:
                with h5py.File(sfile, mode='r') as f:
                    # class label of every detected spike
                    sp_class = f['classes'][()]
                    gaux = f['groups'][()]
                    groups = {
                        g: gaux[gaux[:, 1] == g, 0]
                        for g in np.unique(gaux[:, 1])
                    }  # array of classes per group
                    # group id -> type code (0 and -1 are interpreted below)
                    group_type = {
                        group: g_type
                        for group, g_type in f['types'][()]
                    }
                    sp_index = f['index'][()]

                times_css = fdet[sign]['times'][()]
                # each Combinato group becomes one unit
                for gr, cls in groups.items():

                    # times appear to be in ms: sr/1000 converts to frames --
                    # TODO(review): confirm the unit against Combinato docs
                    spiketrains[unit_counter] = np.rint(
                        times_css[sp_index[np.isin(sp_class, cls)]] *
                        (sampling_frequency / 1000))
                    metadata[unit_counter] = {'group_type': group_type[gr]}
                    unit_counter = unit_counter + 1
        unit_ids = np.arange(unit_counter, dtype='int64')
        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.add_sorting_segment(CombinatoSortingSegment(spiketrains))
        # group_type == 0 flags unsorted units, group_type == -1 artifacts
        self.set_property(
            'unsorted',
            np.array(
                [metadata[u]['group_type'] == 0 for u in range(unit_counter)]))
        self.set_property(
            'artifact',
            np.array([
                metadata[u]['group_type'] == -1 for u in range(unit_counter)
            ]))
        self._kwargs = {
            'folder_path': str(folder_path),
            'user': user,
            'det_sign': det_sign
        }
Example #3
    def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
        """Load a Klusta '.kwik' file (or a folder containing exactly one).

        Parameters
        ----------
        file_or_folder_path : str or Path
            A '.kwik' file, or a folder holding one '.kwik' file plus an
            optional '.prm' config with the sampling rate.
        exclude_cluster_groups : list of str, optional
            Cluster-group names (values of `default_cluster_groups`) to skip.
        """
        assert HAVE_H5PY, self.installation_mesg
        # ~ SortingExtractor.__init__(self)

        kwik_file_or_folder = Path(file_or_folder_path)
        kwikfile = None
        klustafolder = None
        if kwik_file_or_folder.is_file():
            assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
            kwikfile = Path(kwik_file_or_folder).absolute()
            klustafolder = kwikfile.parent
        elif kwik_file_or_folder.is_dir():
            klustafolder = kwik_file_or_folder
            kwikfiles = [
                f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik'
            ]
            if len(kwikfiles) == 1:
                kwikfile = kwikfiles[0]
        assert kwikfile is not None, "Could not load '.kwik' file"

        # BUGFIX: define a fallback so a missing/unreadable '.prm' no longer
        # leaves sampling_frequency unbound (NameError at BaseSorting.__init__)
        sampling_frequency = None
        try:
            config_file = [
                f for f in klustafolder.iterdir() if f.suffix == '.prm'
            ][0]
            config = read_python(str(config_file))
            sampling_frequency = config['traces']['sample_rate']
        except Exception:
            print("Could not load sampling frequency info")

        kf_reader = h5py.File(kwikfile, 'r')
        spiketrains = []
        unit_ids = []
        unique_units = []
        klusta_units = []
        cluster_groups_name = []
        groups = []
        unit = 0

        # validate the requested exclusions against the known group names
        cs_to_exclude = []
        valid_group_names = [
            i[1].lower() for i in self.default_cluster_groups.items()
        ]
        if exclude_cluster_groups is not None:
            assert isinstance(exclude_cluster_groups,
                              list), 'exclude_cluster_groups should be a list'
            for ec in exclude_cluster_groups:
                assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'
                cs_to_exclude.append(ec.lower())

        for channel_group in kf_reader.get('/channel_groups'):
            chan_cluster_id_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
            chan_cluster_times_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/time_samples')[()]
            chan_cluster_ids = np.unique(
                chan_cluster_id_arr)  # if clusters were merged in gui,
            # the original id's are still in the kwiktree, but
            # in this array

            for cluster_id in chan_cluster_ids:
                cluster_frame_idx = np.nonzero(
                    chan_cluster_id_arr ==
                    cluster_id)  # the [()] is a h5py thing
                st = chan_cluster_times_arr[cluster_frame_idx]
                assert st.shape[0] > 0, 'no spikes in cluster'
                cluster_group = kf_reader.get(
                    f'/channel_groups/{channel_group}/clusters/main/{cluster_id}'
                ).attrs['cluster_group']

                assert cluster_group in self.default_cluster_groups.keys(
                ), f'cluster_group not in "default_dict: {cluster_group}'
                cluster_group_name = self.default_cluster_groups[cluster_group]

                if cluster_group_name.lower() in cs_to_exclude:
                    continue

                spiketrains.append(st)

                klusta_units.append(int(cluster_id))
                unique_units.append(unit)
                unit += 1
                groups.append(int(channel_group))
                cluster_groups_name.append(cluster_group_name)

        # prefer Klusta's own ids unless they collide across channel groups
        if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
            unit_ids = klusta_units
        else:
            print('Klusta units are not unique! Using unique unit ids')
            unit_ids = unique_units

        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.is_dumpable = False

        self.add_sorting_segment(KlustSortingSegment(unit_ids, spiketrains))

        self.set_property('group', groups)
        quality = [e.lower() for e in cluster_groups_name]
        self.set_property('quality', quality)

        self._kwargs = {
            'file_or_folder_path': str(Path(file_or_folder_path).absolute())
        }
Example #4
 def __init__(self, sampling_frequency, unit_ids=None):
     """Create a sorting with the given sampling rate and unit ids.

     Parameters
     ----------
     sampling_frequency : float
         Sampling rate in Hz.
     unit_ids : list, optional
         Unit ids; defaults to an empty list.
     """
     # BUGFIX: the original default `unit_ids=[]` is a mutable default
     # argument shared across all calls; use the None-sentinel idiom
     # (backward compatible: omitting the argument still yields []).
     if unit_ids is None:
         unit_ids = []
     BaseSorting.__init__(self, sampling_frequency, unit_ids)
     self.is_dumpable = False
Example #5
    def __init__(self, file_path, keep_good_only=True):
        """Read an HDSort Matlab result file.

        Parameters
        ----------
        file_path : str or Path
            Path to the HDSort .mat result file.
        keep_good_only : bool, default True
            Drop noise units (units whose ID is an exact multiple of 1000).
        """
        MatlabHelper.__init__(self, file_path)

        # The two branches only differ in how fields are read from new-style
        # vs old-style mat files; the shared post-processing (noise filter,
        # sortingInfo) was duplicated in both branches and is factored below.
        if not self._old_style_mat:
            _units = self._data['Units']
            units = _parse_units(self._data, _units)

            # Extracting MutliElectrode field by field:
            _ME = self._data["MultiElectrode"]
            multi_electrode = dict((k, _ME.get(k)[()]) for k in _ME.keys())

            sr = self._data["samplingRate"]
        else:
            _units = self._getfield('Units').squeeze()
            fields = _units.dtype.fields.keys()
            units = [{f: unit[f] for f in fields} for unit in _units]

            sr = self._getfield("samplingRate")

            _ME = self._data["MultiElectrode"]
            multi_electrode = dict(
                (k, _ME[k][0][0].T) for k in _ME.dtype.fields.keys())

        # Extracting sampling_frequency:
        sampling_frequency = float(_squeeze_ds(sr))

        # Remove noise units if necessary:
        if keep_good_only:
            units = [
                unit for unit in units
                if unit["ID"].flatten()[0].astype(int) % 1000 != 0
            ]

        if 'sortingInfo' in self._data.keys():
            if self._old_style_mat:
                info = self._getfield("sortingInfo")
            else:
                info = self._data["sortingInfo"]
            start_frame = _squeeze_ds(info['startTimes'])
            self.start_frame = int(start_frame)
        else:
            self.start_frame = 0

        self._units = units
        self._multi_electrode = multi_electrode

        unit_ids = []
        spiketrains = []
        for unit in units:
            unit_id = int(_squeeze_ds(unit["ID"]))
            # shift spike times so that frame 0 matches the sorting start
            spike_times = _squeeze(
                unit["spikeTrain"]).astype('int64') - self.start_frame
            unit_ids.append(unit_id)
            spiketrains.append(spike_times)

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        self.add_sorting_segment(HDSortSortingSegment(unit_ids, spiketrains))

        # properties: per-unit templates and frames cut before the peak
        templates = []
        templates_frames_cut_before = []
        for unit in units:
            if self._old_style_mat:
                template = unit["footprint"].T
            else:
                template = unit["footprint"]
            templates.append(template)
            templates_frames_cut_before.append(unit["cutLeft"].flatten())
        self.set_property("template", np.array(templates))
        self.set_property("template_frames_cut_before",
                          np.array(templates_frames_cut_before))

        self._kwargs = {
            'file_path': str(file_path),
            'keep_good_only': keep_good_only
        }
Example #6
    def __init__(self, folder_path, sampling_frequency=30000):
        """Load an ALF probe folder (spikes.* / clusters.* datasets).

        Parameters
        ----------
        folder_path : str or Path
            Folder whose name contains 'probe', holding spikes.* and
            clusters.* '.npy' datasets plus an optional clusters.metrics csv.
        sampling_frequency : float, default 30000
            Sampling rate used to convert spike times to frames.
        """
        assert self.installed, self.installation_mesg
        # check correct parent folder:
        self._folder_path = Path(folder_path)
        if 'probe' not in self._folder_path.name:
            raise ValueError(
                'folder name should contain "probe", containing channels, clusters.* .npy datasets'
            )
        # load datasets as mmap into a dict:
        required_alf_datasets = ['spikes.times', 'spikes.clusters']
        found_alf_datasets = dict()
        # BUGFIX: iterate the folder validated above -- the original used
        # 'self.file_loc', an attribute this method never sets
        for alf_dataset_name in self._folder_path.iterdir():
            if 'spikes' in alf_dataset_name.stem or 'clusters' in alf_dataset_name.stem:
                if 'npy' in alf_dataset_name.suffix:
                    dset = np.load(alf_dataset_name,
                                   mmap_mode='r',
                                   allow_pickle=True)
                    found_alf_datasets.update({alf_dataset_name.stem: dset})
                elif 'metrics' in alf_dataset_name.stem:
                    found_alf_datasets.update(
                        {alf_dataset_name.stem: pd.read_csv(alf_dataset_name)})

        # check existence of datasets:
        # BUGFIX: every required dataset must be present; the original
        # 'not any(...)' only failed when *none* was found
        if not all(i in found_alf_datasets for i in required_alf_datasets):
            raise Exception(
                f'could not find {required_alf_datasets} in folder')

        spike_clusters = found_alf_datasets['spikes.clusters']
        spike_times = found_alf_datasets['spikes.times']

        # load units properties:
        total_units = 0
        properties = dict()

        for alf_dataset_name, alf_dataset in found_alf_datasets.items():
            if 'clusters' in alf_dataset_name:
                if 'clusters.metrics' in alf_dataset_name:
                    # BUGFIX: DataFrame.iteritems() was removed in pandas 2.0;
                    # items() is the drop-in replacement
                    for property_name, property_values in found_alf_datasets[
                            alf_dataset_name].items():
                        properties[property_name] = property_values.tolist()
                else:
                    property_name = alf_dataset_name.split('.')[1]
                    properties[property_name] = alf_dataset
                    if total_units == 0:
                        total_units = alf_dataset.shape[0]

        if 'clusters.metrics' in found_alf_datasets and \
                found_alf_datasets['clusters.metrics'].get('cluster_id') is not None:
            unit_ids = found_alf_datasets['clusters.metrics'].get(
                'cluster_id').tolist()
        else:
            unit_ids = list(range(total_units))

        BaseSorting.__init__(self,
                             unit_ids=unit_ids,
                             sampling_frequency=sampling_frequency)
        sorting_segment = ALFSortingSegment(spike_clusters, spike_times,
                                            sampling_frequency)
        self.add_sorting_segment(sorting_segment)

        # add properties
        for property_name, values in properties.items():
            self.set_property(property_name, values)

        self._kwargs = {
            'folder_path': str(Path(folder_path).absolute()),
            'sampling_frequency': sampling_frequency
        }
Example #7
    def __init__(self,
                 file_path: PathType,
                 electrical_series_name: str = None,
                 sampling_frequency: float = None,
                 samples_for_rate_estimation: int = 100000):
        """Read the units table of an NWB file as a sorting.

        Parameters
        ----------
        file_path : PathType
            Path to the .nwb file.
        electrical_series_name : str, optional
            Name of the ElectricalSeries the sorting refers to; used to
            recover the sampling frequency when it is not given.
        sampling_frequency : float, optional
            If None, taken from the electrical series rate, or estimated
            from its timestamps.
        samples_for_rate_estimation : int, default 100000
            Number of leading timestamps used for the rate estimate.
        """
        check_nwb_install()
        self._file_path = str(file_path)
        self._electrical_series_name = electrical_series_name

        io = NWBHDF5IO(self._file_path, mode='r', load_namespaces=True)
        self._nwbfile = io.read()
        # BUGFIX: always bind timestamps -- it is passed to the segment below
        # even when the caller provides sampling_frequency (was a NameError)
        timestamps = None
        if sampling_frequency is None:
            # defines the electrical series from where the sorting came from
            # important to know the sampling_frequency
            self._es = get_electrical_series(self._nwbfile,
                                             self._electrical_series_name)
            # get rate
            if self._es.rate is not None:
                sampling_frequency = self._es.rate
            else:
                if hasattr(self._es, "timestamps"):
                    if self._es.timestamps is not None:
                        timestamps = self._es.timestamps
                        # BUGFIX: slice the first N timestamps; indexing with
                        # the bare int selected a single scalar, so np.diff
                        # had nothing to difference
                        sampling_frequency = 1 / np.median(
                            np.diff(timestamps[:samples_for_rate_estimation]))

        assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                               "'sampling_frequency' argument"

        # get all units ids
        units_ids = list(self._nwbfile.units.id[:])

        # store units properties and spike features to dictionaries
        properties = dict()

        for column in list(self._nwbfile.units.colnames):
            if column == 'spike_times':
                continue
            # if it is unit_property
            property_values = self._nwbfile.units[column][:]

            # only load columns with same shape for all units
            # BUGFIX: np.all() on a generator is always truthy, so the shape
            # check never ran; use built-in all() with np.shape (safe for
            # scalar entries too)
            if all(np.shape(p) == np.shape(property_values[0])
                   for p in property_values):
                properties[column] = property_values
            else:
                print(
                    f"Skipping {column} because of unequal shapes across units"
                )

        BaseSorting.__init__(self,
                             sampling_frequency=sampling_frequency,
                             unit_ids=units_ids)
        sorting_segment = NwbSortingSegment(
            nwbfile=self._nwbfile,
            sampling_frequency=sampling_frequency,
            timestamps=timestamps)
        self.add_sorting_segment(sorting_segment)

        for prop_name, values in properties.items():
            self.set_property(prop_name, np.array(values))

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'electrical_series_name': self._electrical_series_name,
            'sampling_frequency': sampling_frequency,
            'samples_for_rate_estimation': samples_for_rate_estimation
        }
Example #8
    def __init__(self,
                 file_path,
                 electrical_series_name: str = None,
                 sampling_frequency: float = None):
        """
        Parameters
        ----------
        file_path: path to NWB file
        electrical_series_name: str with pynwb.ecephys.ElectricalSeries object name
        sampling_frequency: float
        """
        assert self.installed, self.installation_mesg
        self._file_path = str(file_path)
        # BUGFIX: collect properties across the whole loop -- the original
        # re-created this dict inside the loop (losing earlier columns) and
        # left it undefined when no unit property existed (NameError below)
        properties = dict()
        with NWBHDF5IO(self._file_path, 'r') as io:
            nwbfile = io.read()
            if sampling_frequency is None:
                # defines the electrical series from where the sorting came from
                # important to know the sampling_frequency
                if electrical_series_name is None:
                    if len(nwbfile.acquisition) > 1:
                        raise Exception(
                            'More than one acquisition found. You must specify electrical_series_name.'
                        )
                    if len(nwbfile.acquisition) == 0:
                        raise Exception(
                            "No acquisitions found in the .nwb file from which to read sampling frequency. \
                                         Please, specify 'sampling_frequency' parameter."
                        )
                    es = list(nwbfile.acquisition.values())[0]
                else:
                    # BUGFIX: look the series up by name -- the original bound
                    # the bare name string, which has no .rate/.timestamps
                    es = nwbfile.acquisition[electrical_series_name]
                # get rate
                if es.rate is not None:
                    sampling_frequency = es.rate
                else:
                    sampling_frequency = 1 / (es.timestamps[1] -
                                              es.timestamps[0])

            assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                                   "'sampling_frequency' argument"

            # get all units ids
            units_ids = list(nwbfile.units.id[:])

            # store units properties and spike features to dictionaries
            all_pr_ft = list(nwbfile.units.colnames)
            all_names = [i.name for i in nwbfile.units.columns]
            for item in all_pr_ft:
                if item == 'spike_times':
                    continue
                # test if item is a unit_property or a spike_feature
                if item + '_index' in all_names:  # if it has index, it is a spike_feature
                    pass
                else:  # if it is unit_property
                    for u_id in units_ids:
                        ind = list(units_ids).index(u_id)
                        if isinstance(nwbfile.units[item][ind], pd.DataFrame):
                            prop_value = nwbfile.units[item][ind].index[0]
                        else:
                            prop_value = nwbfile.units[item][ind]

                        if item not in properties:
                            properties[item] = np.zeros(len(units_ids),
                                                        dtype=type(prop_value))
                        # BUGFIX: actually store the value (the original
                        # computed prop_value and then dropped it)
                        properties[item][ind] = prop_value

        # BUGFIX: keyword is 'unit_ids' -- consistent with every other
        # BaseSorting.__init__ call ('units_ids' would be a TypeError)
        BaseSorting.__init__(self,
                             sampling_frequency=sampling_frequency,
                             unit_ids=units_ids)
        sorting_segment = NwbSortingSegment(
            path=self._file_path, sampling_frequency=sampling_frequency)
        self.add_sorting_segment(sorting_segment)

        for prop_name, values in properties.items():
            self.set_property(prop_name, values)

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'electrical_series_name': electrical_series_name,
            'sampling_frequency': sampling_frequency
        }