Example #1
    def __init__(self,
                 sampling_frequency=None,
                 use_natural_unit_ids=False,
                 **neo_kwargs):
        _NeoBaseExtractor.__init__(self, **neo_kwargs)

        self.use_natural_unit_ids = use_natural_unit_ids

        if sampling_frequency is None:
            sampling_frequency = self._auto_guess_sampling_frequency()

        spike_channels = self.neo_reader.header['spike_channels']

        if use_natural_unit_ids:
            unit_ids = spike_channels['id']
            assert np.unique(unit_ids).size == unit_ids.size, \
                'unit_ids have duplicates'
        else:
            # use integer-based unit_ids
            unit_ids = np.arange(spike_channels.size, dtype='int64')

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        nseg = self.neo_reader.segment_count(block_index=0)
        for segment_index in range(nseg):
            if self.handle_spike_frame_directly:
                t_start = None
            else:
                t_start = self.neo_reader.get_signal_t_start(0, segment_index)

            sorting_segment = NeoSortingSegment(self.neo_reader, segment_index,
                                                self.use_natural_unit_ids,
                                                t_start, sampling_frequency)
            self.add_sorting_segment(sorting_segment)
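
The unit-id handling above can be exercised in isolation. A minimal sketch, with a made-up `spike_channels` record array standing in for `neo_reader.header['spike_channels']`:

import numpy as np

# hypothetical stand-in for neo_reader.header['spike_channels']
spike_channels = np.array([('ch0#0',), ('ch0#1',), ('ch1#0',)],
                          dtype=[('id', 'U16')])

use_natural_unit_ids = True
if use_natural_unit_ids:
    # use the ids stored by the file format, which must be unique
    unit_ids = spike_channels['id']
    assert np.unique(unit_ids).size == unit_ids.size, 'unit_ids have duplicates'
else:
    # fall back to integer-based unit ids, one per spike channel
    unit_ids = np.arange(spike_channels.size, dtype='int64')

print(unit_ids)  # ['ch0#0' 'ch0#1' 'ch1#0']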
Example #2
    def __init__(self, file_path, load_unit_info=True):
        assert self.installed, self.installation_mesg

        self._recording_file = file_path
        self._rf = h5py.File(self._recording_file, mode='r')
        # a stored sampling rate of 0 means "unknown"
        sampling_frequency = None
        if 'Sampling' in self._rf and self._rf['Sampling'][()] != 0:
            sampling_frequency = self._rf['Sampling'][()]

        spike_ids = self._rf['cluster_id'][()]
        unit_ids = np.unique(spike_ids)
        spike_times = self._rf['times'][()]

        if load_unit_info:
            self.load_unit_info()

        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.add_sorting_segment(
            HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids))
        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'load_unit_info': load_unit_info
        }
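
The flat `cluster_id`/`times` arrays are regrouped into per-unit spike trains inside `HerdingspikesSortingSegment`. With plain NumPy the same grouping looks roughly like this (synthetic arrays instead of the HDF5 datasets):

import numpy as np

# synthetic stand-ins for self._rf['cluster_id'][()] and self._rf['times'][()]
spike_ids = np.array([0, 1, 0, 2, 1, 0])
spike_times = np.array([10, 12, 25, 30, 41, 50])

unit_ids = np.unique(spike_ids)
spiketrains = {u: spike_times[spike_ids == u] for u in unit_ids}
print(spiketrains)  # {0: array([10, 25, 50]), 1: array([12, 41]), 2: array([30])}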
Example #3
    def __init__(self, folder_path, chan_grp=None):
        try:
            import tridesclous as tdc
            HAVE_TDC = True
        except ImportError:
            HAVE_TDC = False
        assert HAVE_TDC, self.installation_mesg
        tdc_folder = Path(folder_path)

        dataio = tdc.DataIO(str(tdc_folder))
        if chan_grp is None:
            # if chan_grp is not provided, take the first one if unique
            chan_grps = list(dataio.channel_groups.keys())
            assert len(chan_grps) == 1, 'There are several groups in the folder, specify chan_grp=...'
            chan_grp = chan_grps[0]

        catalogue = dataio.load_catalogue(name='initial', chan_grp=chan_grp)

        labels = catalogue['clusters']['cluster_label']
        labels = labels[labels >= 0]
        unit_ids = list(labels)

        sampling_frequency = dataio.sample_rate

        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        for seg_num in range(dataio.nb_segment):
            # load all spikes into memory (this avoids locking the folder with a memmap through dataio)
            all_spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp, i_start=None, i_stop=None).copy()
            self.add_sorting_segment(TridesclousSortingSegment(all_spikes))

        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'chan_grp': chan_grp}
Example #4
    def __init__(self, oldapi_sorting_extractor):
        BaseSorting.__init__(self,
                             sampling_frequency=oldapi_sorting_extractor.get_sampling_frequency(),
                             unit_ids=oldapi_sorting_extractor.get_unit_ids())

        sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
        self.add_sorting_segment(sorting_segment)
Example #5
    def __init__(self, folder_path):
        assert HAVE_H5PY, self.installation_mesg

        spykingcircus_folder = Path(folder_path)
        listfiles = spykingcircus_folder.iterdir()

        parent_folder = None
        result_folder = None
        for f in listfiles:
            if f.is_dir():
                if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
                    parent_folder = spykingcircus_folder
                    result_folder = f

        if parent_folder is None:
            parent_folder = spykingcircus_folder.parent
            for f in parent_folder.iterdir():
                if f.is_dir():
                    if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):
                        result_folder = spykingcircus_folder

        assert isinstance(parent_folder, Path) and isinstance(
            result_folder, Path), "Not a valid spyking circus folder"

        # load files
        results = None
        for f in result_folder.iterdir():
            if 'result.hdf5' in str(f):
                results = f
            if 'result-merged.hdf5' in str(f):
                results = f
                break
        if results is None:
            raise Exception(f"{spykingcircus_folder} is not a spyking circus folder")

        # load params
        sample_rate = None
        for f in parent_folder.iterdir():
            if f.suffix == '.params':
                sample_rate = _load_sample_rate(f)

        assert sample_rate is not None, 'sample rate not found'

        with h5py.File(results, 'r') as f_results:
            spiketrains = []
            unit_ids = []
            for temp in f_results['spiketimes'].keys():
                spiketrains.append(
                    np.array(f_results['spiketimes'][temp]).astype('int64'))
                unit_ids.append(int(temp.split('_')[-1]))

        BaseSorting.__init__(self, sample_rate, unit_ids)
        self.add_sorting_segment(
            SpykingcircustSortingSegment(unit_ids, spiketrains))

        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}
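
The results file keeps one dataset per template under `spiketimes`, with the unit id encoded in the key name. A self-contained h5py sketch of the same loop (it first writes a tiny throwaway file with that layout; only the parts used by the loop above are assumed here):

import h5py
import numpy as np

# write a minimal stand-in results file with a 'spiketimes/temp_<id>' layout
with h5py.File('results_demo.hdf5', 'w') as f:
    f.create_dataset('spiketimes/temp_0', data=np.array([10, 55, 200]))
    f.create_dataset('spiketimes/temp_1', data=np.array([42, 90]))

spiketrains = []
unit_ids = []
with h5py.File('results_demo.hdf5', 'r') as f_results:
    for temp in f_results['spiketimes'].keys():
        spiketrains.append(np.array(f_results['spiketimes'][temp]).astype('int64'))
        unit_ids.append(int(temp.split('_')[-1]))

print(unit_ids)  # [0, 1]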
Example #6
    def __init__(self, file_path, sampling_frequency):
        firings = readmda(str(file_path))
        labels = firings[2, :]
        unit_ids = np.unique(labels).astype(int)
        BaseSorting.__init__(self,
                             unit_ids=unit_ids,
                             sampling_frequency=sampling_frequency)

        sorting_segment = MdaSortingSegment(firings)
        self.add_sorting_segment(sorting_segment)

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'sampling_frequency': sampling_frequency
        }
Example #7
    def __init__(self, oldapi_sorting_extractor):
        BaseSorting.__init__(self,
                             sampling_frequency=oldapi_sorting_extractor.get_sampling_frequency(),
                             unit_ids=oldapi_sorting_extractor.get_unit_ids())

        sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
        self.add_sorting_segment(sorting_segment)

        self.is_dumpable = False

        # add old properties
        copy_properties(oldapi_extractor=oldapi_sorting_extractor,
                        new_extractor=self)

        self._kwargs = {'oldapi_sorting_extractor': oldapi_sorting_extractor}
Example #8
    def __init__(self, folder_path, sampling_frequency=None, user='******', det_sign='both', keep_good_only=True):

        folder_path = Path(folder_path)
        assert folder_path.is_dir(), 'Folder {} doesn\'t exist'.format(folder_path)
        if sampling_frequency is None:
            h5_path = str(folder_path) + '.h5'
            if Path(h5_path).exists():
                with h5py.File(h5_path, mode='r') as f:
                    sampling_frequency = f['sr'][0]

        # ~ self.set_sampling_frequency(sampling_frequency)
        det_file = str(folder_path / Path('data_' + folder_path.stem + '.h5'))
        sort_cat_files = []
        for sign in ['neg', 'pos']:
            if det_sign in ['both', sign]:
                sort_cat_file = folder_path / Path('sort_{}_{}/sort_cat.h5'.format(sign, user))
                if sort_cat_file.exists():
                    sort_cat_files.append((sign, str(sort_cat_file)))

        unit_counter = 0
        spiketrains = {}
        metadata = {}
        unsorted = []
        with h5py.File(det_file, mode='r') as fdet:
            for sign, sfile in sort_cat_files:
                with h5py.File(sfile, mode='r') as f:
                    sp_class = f['classes'][()]
                    gaux = f['groups'][()]
                    groups = {g: gaux[gaux[:, 1] == g, 0] for g in np.unique(gaux[:, 1])}  # array of classes per group
                    group_type = {group: g_type for group, g_type in f['types'][()]}
                    sp_index = f['index'][()]

                times_css = fdet[sign]['times'][()]
                for gr, cls in groups.items():
                    if keep_good_only and (group_type[gr] < 1):  # artifact or unsorted
                        continue
                    spiketrains[unit_counter] = np.rint(
                        times_css[sp_index[np.isin(sp_class, cls)]] * (sampling_frequency / 1000))
                    metadata[unit_counter] = {'group_type': group_type[gr]}
                    unit_counter = unit_counter + 1
        unit_ids = np.arange(unit_counter, dtype='int64')
        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.add_sorting_segment(CombinatoSortingSegment(spiketrains))
        self.set_property('unsorted', np.array([metadata[u]['group_type'] == 0 for u in range(unit_counter)]))
        self.set_property('artifact', np.array([metadata[u]['group_type'] == -1 for u in range(unit_counter)]))
        self._kwargs = {'folder_path': str(folder_path), 'user': user, 'det_sign': det_sign}
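
The `groups`/`np.isin` bookkeeping above selects every detected spike whose class belongs to a given group and converts its time from milliseconds to sample frames. A self-contained sketch with made-up arrays in place of the HDF5 datasets:

import numpy as np

# made-up stand-ins for the arrays read from sort_cat.h5 and the detection file
sp_class = np.array([1, 2, 1, 3, 2, 1])       # class label of each sorted spike
sp_index = np.array([0, 1, 3, 4, 6, 7])       # index of each sorted spike in times_css
times_css = np.array([5., 8., 9., 14., 20., 22., 31., 40.])  # detection times in ms
gaux = np.array([[1, 0], [2, 0], [3, 1]])     # columns: class, group
sampling_frequency = 20000.0

# array of classes per group, as in the loop above
groups = {g: gaux[gaux[:, 1] == g, 0] for g in np.unique(gaux[:, 1])}

for gr, cls in groups.items():
    # spikes whose class is in this group, converted from ms to frames
    frames = np.rint(times_css[sp_index[np.isin(sp_class, cls)]] * (sampling_frequency / 1000))
    print(gr, frames)  # 0 -> [100. 160. 280. 620. 800.], 1 -> [400.]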
Example #9
    def __init__(self,
                 sampling_frequency,
                 multisortingcomparison,
                 min_agreement_count=1,
                 min_agreement_count_only=False):

        self._msc = multisortingcomparison
        self.is_dumpable = False

        # TODO: @alessio I leave this for you
        # if min_agreement_count_only:
        # self._unit_ids = list(u for u in self._msc._new_units.keys()
        # if self._msc._new_units[u]['agreement_number'] == min_agreement_count)
        # else:
        # self._unit_ids = list(u for u in self._msc._new_units.keys()
        # if self._msc._new_units[u]['agreement_number'] >= min_agreement_count)

        # for unit in self._unit_ids:
        # self.set_unit_property(unit_id=unit, property_name='agreement_number',
        # value=self._msc._new_units[unit]['agreement_number'])
        # self.set_unit_property(unit_id=unit, property_name='avg_agreement',
        # value=self._msc._new_units[unit]['avg_agreement'])
        # self.set_unit_property(unit_id=unit, property_name='sorter_unit_ids',
        # value=self._msc._new_units[unit]['sorter_unit_ids'])

        if min_agreement_count_only:
            unit_ids = list(u for u in self._msc._new_units.keys()
                            if self._msc._new_units[u]['agreement_number'] ==
                            min_agreement_count)
        else:
            unit_ids = list(u for u in self._msc._new_units.keys()
                            if self._msc._new_units[u]['agreement_number'] >=
                            min_agreement_count)

        BaseSorting.__init__(self,
                             sampling_frequency=sampling_frequency,
                             unit_ids=unit_ids)
        if len(unit_ids) > 0:
            for k in ('agreement_number', 'avg_agreement', 'sorter_unit_ids'):
                values = [
                    self._msc._new_units[unit_id][k] for unit_id in unit_ids
                ]
                self.set_property(k, values, ids=unit_ids)

        sorting_segment = AgreementSortingSegment(multisortingcomparison)
        self.add_sorting_segment(sorting_segment)
Example #10
    def __init__(self, file_path, keep_good_only=True):
        MatlabHelper.__init__(self, file_path)

        cluster_classes = self._getfield("cluster_class")
        classes = cluster_classes[:, 0]
        spike_times = cluster_classes[:, 1]
        par = self._getfield("par")
        sampling_frequency = par[0, 0][np.where(np.array(par.dtype.names) == 'sr')[0][0]][0][0]
        unit_ids = np.unique(classes).astype('int')
        if keep_good_only:
            unit_ids = unit_ids[unit_ids > 0]
        spiketrains = {}
        for unit_id in unit_ids:
            mask = (classes == unit_id)
            spiketrains[unit_id] = np.rint(spike_times[mask] * (sampling_frequency / 1000))

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        self.add_sorting_segment(WaveClustSortingSegment(unit_ids, spiketrains))
        self.set_property('unsorted', np.array([c == 0 for c in unit_ids]))
        self._kwargs = {'file_path': str(Path(file_path).absolute())}
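
As in the Combinato example, wave_clus stores spike times in milliseconds, so `np.rint(spike_times * (sampling_frequency / 1000))` converts them to sample frames. A quick worked example:

import numpy as np

sampling_frequency = 30000.0                    # Hz
spike_times_ms = np.array([1.5, 10.0, 250.2])   # ms

frames = np.rint(spike_times_ms * (sampling_frequency / 1000))
print(frames)  # [  45.  300. 7506.]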
Example #11
    def __init__(self, folder_path):
        assert HAVE_YASS, self.installation_mesg

        folder_path = Path(folder_path)

        self.fname_spike_train = folder_path / 'tmp' / 'output' / 'spike_train.npy'
        self.fname_templates = folder_path / 'tmp' / 'output' / 'templates' / 'templates_0sec.npy'
        self.fname_config = folder_path / 'config.yaml'

        # Read CONFIG File
        with open(self.fname_config, 'r') as stream:
            self.config = yaml.safe_load(stream)

        spiketrains = np.load(self.fname_spike_train)
        unit_ids = np.unique(spiketrains[:, 1])

        # initialize
        sampling_frequency = self.config['recordings']['sampling_rate']
        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.add_sorting_segment(YassSortingSegment(spiketrains))

        self._kwargs = {'folder_path': str(folder_path)}
Example #12
    def __init__(self, sampling_frequency, multisortingcomparison,
                 min_agreement_count=1, min_agreement_count_only=False):

        self._msc = multisortingcomparison
        self.is_dumpable = False

        if min_agreement_count_only:
            unit_ids = list(u for u in self._msc._new_units.keys()
                            if self._msc._new_units[u]['agreement_number'] == min_agreement_count)
        else:
            unit_ids = list(u for u in self._msc._new_units.keys()
                            if self._msc._new_units[u]['agreement_number'] >= min_agreement_count)

        BaseSorting.__init__(
            self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)
        if len(unit_ids) > 0:
            for k in ('agreement_number', 'avg_agreement', 'unit_ids'):
                values = [self._msc._new_units[unit_id][k]
                          for unit_id in unit_ids]
                self.set_property(k, values, ids=unit_ids)

        sorting_segment = AgreementSortingSegment(multisortingcomparison)
        self.add_sorting_segment(sorting_segment)
Example #13
    def __init__(self, file_path, sampling_frequency, delimiter=','):
        assert self.installed, self.installation_mesg
        assert Path(file_path).suffix == ".csv", "The 'file_path' should be a csv file!"

        if Path(file_path).is_file():
            spike_clusters = sbio.SpikeClusters()
            spike_clusters.fromCSV(str(file_path), None, delimiter=delimiter)
        else:
            raise FileNotFoundError(
                f'The ground truth file {file_path} could not be found')

        BaseSorting.__init__(self,
                             unit_ids=spike_clusters.keys(),
                             sampling_frequency=sampling_frequency)

        sorting_segment = SHYBRIDSortingSegment(spike_clusters)
        self.add_sorting_segment(sorting_segment)

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'sampling_frequency': sampling_frequency,
            'delimiter': delimiter
        }
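
`sbio.SpikeClusters().fromCSV` comes from the SHYBRID/hybridizer package; the ground-truth CSV is just spike/cluster pairs. Assuming a two-column layout of (cluster id, spike frame) (a guess for illustration, not the documented SHYBRID format), an equivalent plain-NumPy grouping would be:

import io
import numpy as np

# hypothetical two-column ground-truth CSV: cluster id, spike frame
csv_text = "0,100\n1,220\n0,305\n1,990\n"
data = np.loadtxt(io.StringIO(csv_text), delimiter=',', dtype='int64')

spike_clusters = {u: data[data[:, 0] == u, 1] for u in np.unique(data[:, 0])}
print(spike_clusters)  # {0: array([100, 305]), 1: array([220, 990])}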
Example #14
    def __init__(self, file_or_folder_path, exclude_cluster_groups=None):
        assert HAVE_H5PY, self.installation_mesg
        # ~ SortingExtractor.__init__(self)

        kwik_file_or_folder = Path(file_or_folder_path)
        kwikfile = None
        klustafolder = None
        if kwik_file_or_folder.is_file():
            assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file"
            kwikfile = Path(kwik_file_or_folder).absolute()
            klustafolder = kwikfile.parent
        elif kwik_file_or_folder.is_dir():
            klustafolder = kwik_file_or_folder
            kwikfiles = [
                f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik'
            ]
            if len(kwikfiles) == 1:
                kwikfile = kwikfiles[0]
        assert kwikfile is not None, "Could not load '.kwik' file"

        sampling_frequency = None
        try:
            config_file = [
                f for f in klustafolder.iterdir() if f.suffix == '.prm'
            ][0]
            config = read_python(str(config_file))
            sampling_frequency = config['traces']['sample_rate']
        except Exception:
            print("Could not load sampling frequency info")

        kf_reader = h5py.File(kwikfile, 'r')
        spiketrains = []
        unit_ids = []
        unique_units = []
        klusta_units = []
        cluster_groups_name = []
        groups = []
        unit = 0

        cs_to_exclude = []
        valid_group_names = [
            i[1].lower() for i in self.default_cluster_groups.items()
        ]
        if exclude_cluster_groups is not None:
            assert isinstance(exclude_cluster_groups,
                              list), 'exclude_cluster_groups should be a list'
            for ec in exclude_cluster_groups:
                assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'
                cs_to_exclude.append(ec.lower())

        for channel_group in kf_reader.get('/channel_groups'):
            chan_cluster_id_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/clusters/main')[()]
            chan_cluster_times_arr = kf_reader.get(
                f'/channel_groups/{channel_group}/spikes/time_samples')[()]
            chan_cluster_ids = np.unique(chan_cluster_id_arr)
            # if clusters were merged in the gui, the original ids are still
            # in the kwik tree, but not in this array

            for cluster_id in chan_cluster_ids:
                cluster_frame_idx = np.nonzero(
                    chan_cluster_id_arr ==
                    cluster_id)  # the [()] is a h5py thing
                st = chan_cluster_times_arr[cluster_frame_idx]
                assert st.shape[0] > 0, 'no spikes in cluster'
                cluster_group = kf_reader.get(
                    f'/channel_groups/{channel_group}/clusters/main/{cluster_id}'
                ).attrs['cluster_group']

                assert cluster_group in self.default_cluster_groups.keys(), \
                    f'cluster_group not in default_cluster_groups: {cluster_group}'
                cluster_group_name = self.default_cluster_groups[cluster_group]

                if cluster_group_name.lower() in cs_to_exclude:
                    continue

                spiketrains.append(st)

                klusta_units.append(int(cluster_id))
                unique_units.append(unit)
                unit += 1
                groups.append(int(channel_group))
                cluster_groups_name.append(cluster_group_name)

        if len(np.unique(klusta_units)) == len(np.unique(unique_units)):
            unit_ids = klusta_units
        else:
            print('Klusta units are not unique! Using unique unit ids')
            unit_ids = unique_units

        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.is_dumpable = False

        self.add_sorting_segment(KlustSortingSegment(unit_ids, spiketrains))

        self.set_property('group', groups)
        quality = [e.lower() for e in cluster_groups_name]
        self.set_property('quality', quality)

        self._kwargs = {
            'file_or_folder_path': str(Path(file_or_folder_path).absolute())
        }
Example #15
    def __init__(self,
                 file_path,
                 electrical_series_name: str = None,
                 sampling_frequency: float = None):
        """
        Parameters
        ----------
        file_path: path to NWB file
        electrical_series_name: str with pynwb.ecephys.ElectricalSeries object name
        sampling_frequency: float
        """
        assert self.installed, self.installation_mesg
        self._file_path = str(file_path)
        with NWBHDF5IO(self._file_path, 'r') as io:
            nwbfile = io.read()
            if sampling_frequency is None:
                # find the electrical series the sorting came from;
                # it is needed to get the sampling_frequency
                if electrical_series_name is None:
                    if len(nwbfile.acquisition) > 1:
                        raise Exception(
                            'More than one acquisition found. You must specify electrical_series_name.'
                        )
                    if len(nwbfile.acquisition) == 0:
                        raise Exception(
                            "No acquisitions found in the .nwb file from which to read sampling frequency. \
                                         Please, specify 'sampling_frequency' parameter."
                        )
                    es = list(nwbfile.acquisition.values())[0]
                else:
                    es = nwbfile.acquisition[electrical_series_name]
                # get rate
                if es.rate is not None:
                    sampling_frequency = es.rate
                else:
                    sampling_frequency = 1 / (es.timestamps[1] -
                                              es.timestamps[0])

            assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                                   "'sampling_frequency' argument"

            # get all units ids
            units_ids = list(nwbfile.units.id[:])

            # store units properties and spike features to dictionaries
            all_pr_ft = list(nwbfile.units.colnames)
            all_names = [i.name for i in nwbfile.units.columns]
            properties = dict()
            for item in all_pr_ft:
                if item == 'spike_times':
                    continue
                # test if item is a unit_property or a spike_feature
                if item + '_index' in all_names:  # if it has an index, it is a spike_feature
                    continue
                # it is a unit_property: collect one value per unit
                for ind, u_id in enumerate(units_ids):
                    if isinstance(nwbfile.units[item][ind], pd.DataFrame):
                        prop_value = nwbfile.units[item][ind].index[0]
                    else:
                        prop_value = nwbfile.units[item][ind]

                    if item not in properties:
                        properties[item] = np.zeros(len(units_ids),
                                                    dtype=type(prop_value))
                    properties[item][ind] = prop_value

        BaseSorting.__init__(self,
                             sampling_frequency=sampling_frequency,
                             unit_ids=units_ids)
        sorting_segment = NwbSortingSegment(
            path=self._file_path, sampling_frequency=sampling_frequency)
        self.add_sorting_segment(sorting_segment)

        for prop_name, values in properties.items():
            self.set_property(prop_name, values)

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'electrical_series_name': electrical_series_name,
            'sampling_frequency': sampling_frequency
        }
Example #16
    def __init__(self, sampling_frequency, unit_ids=[]):
        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.is_dumpable = False
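
All of these extractors follow the same pattern as this minimal one: derive the unit ids and per-unit spike trains from the source, call `BaseSorting.__init__(self, sampling_frequency, unit_ids)`, then register one sorting segment per recording segment with `add_sorting_segment`. A minimal sketch of a custom extractor built on an in-memory dict, assuming `BaseSorting`/`BaseSortingSegment` are importable from `spikeinterface.core` and that a segment provides `get_unit_spike_train(unit_id, start_frame, end_frame)`, as the *SortingSegment classes in these examples do:

import numpy as np
from spikeinterface.core import BaseSorting, BaseSortingSegment  # assumed import path


class DictSortingSegment(BaseSortingSegment):
    def __init__(self, spiketrains):
        BaseSortingSegment.__init__(self)
        self._spiketrains = spiketrains  # {unit_id: array of spike frames}

    def get_unit_spike_train(self, unit_id, start_frame, end_frame):
        frames = self._spiketrains[unit_id]
        if start_frame is not None:
            frames = frames[frames >= start_frame]
        if end_frame is not None:
            frames = frames[frames < end_frame]
        return frames


class DictSortingExtractor(BaseSorting):
    def __init__(self, spiketrains, sampling_frequency):
        unit_ids = list(spiketrains.keys())
        BaseSorting.__init__(self, sampling_frequency, unit_ids)
        self.add_sorting_segment(DictSortingSegment(spiketrains))


sorting = DictSortingExtractor({0: np.array([10, 55, 200], dtype='int64')},
                               sampling_frequency=30000.0)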
Example #17
    def __init__(self,
                 folder_path,
                 exclude_cluster_groups=None,
                 keep_good_only=False):
        try:
            import pandas as pd
            HAVE_PD = True
        except ImportError:
            HAVE_PD = False
        assert HAVE_PD, self.installation_mesg

        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        clust_id = np.unique(spike_clusters)
        unit_ids = list(clust_id)
        spike_times = spike_times.astype(int)
        params = read_python(str(phy_folder / 'params.py'))
        sampling_frequency = params['sample_rate']

        # try to load cluster info
        cluster_info_files = [
            p for p in phy_folder.iterdir()
            if p.suffix in ['.csv', '.tsv'] and "cluster_info" in p.name
        ]

        if len(cluster_info_files) == 1:
            # load properties from cluster_info file
            cluster_info_file = cluster_info_files[0]
            if cluster_info_file.suffix == ".tsv":
                delimeter = "\t"
            else:
                delimeter = ","
            cluster_info = pd.read_csv(cluster_info_file, delimiter=delimeter)
        else:
            # load properties from other tsv/csv files
            all_property_files = [
                p for p in phy_folder.iterdir()
                if p.suffix in ['.csv', '.tsv']
            ]

            cluster_info = None
            for file in all_property_files:
                if file.suffix == ".tsv":
                    delimeter = "\t"
                else:
                    delimeter = ","
                new_property = pd.read_csv(file, delimiter=delimeter)
                if cluster_info is None:
                    cluster_info = new_property
                else:
                    cluster_info = pd.merge(cluster_info,
                                            new_property,
                                            on='cluster_id')

        # in case no tsv/csv files are found populate cluster info with minimal info
        if cluster_info is None:
            cluster_info = pd.DataFrame({'cluster_id': unit_ids})
            cluster_info['group'] = ['unsorted'] * len(unit_ids)

        if exclude_cluster_groups is not None:
            if isinstance(exclude_cluster_groups, str):
                cluster_info = cluster_info.query(
                    f"group != '{exclude_cluster_groups}'")
            elif isinstance(exclude_cluster_groups, list):
                if len(exclude_cluster_groups) > 0:
                    for exclude_group in exclude_cluster_groups:
                        cluster_info = cluster_info.query(
                            f"group != '{exclude_group}'")

        if keep_good_only and "KSLabel" in cluster_info.columns:
            cluster_info = cluster_info.query("KSLabel == 'good'")

        if "cluster_id" not in cluster_info.columns:
            assert "id" in cluster_info.columns, "Couldn't find cluster ids in the tsv files!"
            cluster_info["cluster_id"] = cluster_info["id"]
            del cluster_info["id"]

        if 'si_unit_id' in cluster_info.columns:
            unit_ids = cluster_info["si_unit_id"].values
            del cluster_info["si_unit_id"]
        else:
            unit_ids = cluster_info["cluster_id"].values

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        del cluster_info["cluster_id"]
        for prop_name in cluster_info.columns:
            if prop_name in ['chan_grp', 'ch_group']:
                self.set_property(key="group", values=cluster_info[prop_name])
            elif prop_name == "group":
                # rename the 'group' column to the 'quality' property
                self.set_property(key="quality",
                                  values=cluster_info[prop_name])
            else:
                self.set_property(key=prop_name,
                                  values=cluster_info[prop_name])

        self.add_sorting_segment(PhySortingSegment(spike_times,
                                                   spike_clusters))
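
The per-cluster property files that get merged above are plain TSV/CSV tables keyed by `cluster_id`, which is why pairwise `pd.merge(..., on='cluster_id')` reassembles them into one `cluster_info` table. A small pandas sketch with made-up tables in place of the files:

import pandas as pd

# made-up stand-ins for two of phy's cluster_*.tsv tables
group = pd.DataFrame({'cluster_id': [0, 1, 2], 'group': ['good', 'mua', 'good']})
amplitude = pd.DataFrame({'cluster_id': [0, 1, 2], 'Amplitude': [81.2, 40.5, 66.0]})

cluster_info = None
for new_property in (group, amplitude):
    if cluster_info is None:
        cluster_info = new_property
    else:
        cluster_info = pd.merge(cluster_info, new_property, on='cluster_id')

print(cluster_info)
#    cluster_id group  Amplitude
# 0           0  good       81.2
# 1           1   mua       40.5
# 2           2  good       66.0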
Example #18
    def __init__(self,
                 file_path: PathType,
                 electrical_series_name: str = None,
                 sampling_frequency: float = None,
                 samples_for_rate_estimation: int = 100000):
        check_nwb_install()
        self._file_path = str(file_path)
        self._electrical_series_name = electrical_series_name

        io = NWBHDF5IO(self._file_path, mode='r', load_namespaces=True)
        self._nwbfile = io.read()
        if sampling_frequency is None:
            # find the electrical series the sorting came from;
            # it is needed to get the sampling_frequency
            self._es = get_electrical_series(self._nwbfile,
                                             self._electrical_series_name)
            # get rate
            timestamps = None
            if self._es.rate is not None:
                sampling_frequency = self._es.rate
            else:
                if hasattr(self._es, "timestamps"):
                    if self._es.timestamps is not None:
                        timestamps = self._es.timestamps
                        sampling_frequency = 1 / np.median(
                            np.diff(timestamps[:samples_for_rate_estimation]))

        assert sampling_frequency is not None, "Couldn't load sampling frequency. Please provide it with the " \
                                               "'sampling_frequency' argument"

        # get all units ids
        units_ids = list(self._nwbfile.units.id[:])

        # store units properties and spike features to dictionaries
        properties = dict()

        for column in list(self._nwbfile.units.colnames):
            if column == 'spike_times':
                continue
            # if it is unit_property
            property_values = self._nwbfile.units[column][:]

            # only load columns with same shape for all units
            if all(p.shape == property_values[0].shape
                   for p in property_values):
                properties[column] = property_values
            else:
                print(
                    f"Skipping {column} because of unequal shapes across units"
                )

        BaseSorting.__init__(self,
                             sampling_frequency=sampling_frequency,
                             unit_ids=units_ids)
        sorting_segment = NwbSortingSegment(
            nwbfile=self._nwbfile,
            sampling_frequency=sampling_frequency,
            timestamps=timestamps)
        self.add_sorting_segment(sorting_segment)

        for prop_name, values in properties.items():
            self.set_property(prop_name, np.array(values))

        self._kwargs = {
            'file_path': str(Path(file_path).absolute()),
            'electrical_series_name': self._electrical_series_name,
            'sampling_frequency': sampling_frequency,
            'samples_for_rate_estimation': samples_for_rate_estimation
        }
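
When the ElectricalSeries has no fixed `rate`, the constructor estimates it as the reciprocal of the median timestamp interval over the first `samples_for_rate_estimation` samples. In NumPy terms, with synthetic, slightly jittered 30 kHz timestamps standing in for `es.timestamps`:

import numpy as np

# synthetic ~30 kHz timestamps in seconds (stand-in for es.timestamps)
rng = np.random.default_rng(0)
timestamps = np.arange(30000) / 30000.0 + rng.normal(0, 1e-7, 30000)
samples_for_rate_estimation = 100000

sampling_frequency = 1 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))
print(round(sampling_frequency))  # ~30000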
Example #19
    def __init__(self, folder_path, sampling_frequency=30000):
        assert self.installed, self.installation_mesg
        # check correct parent folder:
        self._folder_path = Path(folder_path)
        if 'probe' not in self._folder_path.name:
            raise ValueError(
                'folder name should contain "probe", containing channels, clusters.* .npy datasets'
            )
        # load datasets as mmap into a dict:
        required_alf_datasets = ['spikes.times', 'spikes.clusters']
        found_alf_datasets = dict()
        for alf_dataset_name in self._folder_path.iterdir():
            if 'spikes' in alf_dataset_name.stem or 'clusters' in alf_dataset_name.stem:
                if 'npy' in alf_dataset_name.suffix:
                    dset = np.load(alf_dataset_name,
                                   mmap_mode='r',
                                   allow_pickle=True)
                    found_alf_datasets.update({alf_dataset_name.stem: dset})
                elif 'metrics' in alf_dataset_name.stem:
                    found_alf_datasets.update(
                        {alf_dataset_name.stem: pd.read_csv(alf_dataset_name)})

        # check existence of datasets:
        if not all(i in found_alf_datasets for i in required_alf_datasets):
            raise Exception(
                f'could not find {required_alf_datasets} in folder')

        spike_clusters = found_alf_datasets['spikes.clusters']
        spike_times = found_alf_datasets['spikes.times']

        # load units properties:
        total_units = 0
        properties = dict()

        for alf_dataset_name, alf_dataset in found_alf_datasets.items():
            if 'clusters' in alf_dataset_name:
                if 'clusters.metrics' in alf_dataset_name:
                    for property_name, property_values in found_alf_datasets[
                            alf_dataset_name].items():
                        properties[property_name] = property_values.tolist()
                else:
                    property_name = alf_dataset_name.split('.')[1]
                    properties[property_name] = alf_dataset
                    if total_units == 0:
                        total_units = alf_dataset.shape[0]

        if 'clusters.metrics' in found_alf_datasets and \
                found_alf_datasets['clusters.metrics'].get('cluster_id') is not None:
            unit_ids = found_alf_datasets['clusters.metrics'].get(
                'cluster_id').tolist()
        else:
            unit_ids = list(range(total_units))

        BaseSorting.__init__(self,
                             unit_ids=unit_ids,
                             sampling_frequency=sampling_frequency)
        sorting_segment = ALFSortingSegment(spike_clusters, spike_times,
                                            sampling_frequency)
        self.add_sorting_segment(sorting_segment)

        # add properties
        for property_name, values in properties.items():
            self.set_property(property_name, values)

        self._kwargs = {
            'folder_path': str(Path(folder_path).absolute()),
            'sampling_frequency': sampling_frequency
        }
Example #20
    def __init__(self, file_path, keep_good_only=True):
        MatlabHelper.__init__(self, file_path)

        if not self._old_style_mat:
            _units = self._data['Units']
            units = _parse_units(self._data, _units)

            # Extracting MultiElectrode field by field:
            _ME = self._data["MultiElectrode"]
            multi_electrode = dict((k, _ME.get(k)[()]) for k in _ME.keys())

            # Extracting sampling_frequency:
            sr = self._data["samplingRate"]
            sampling_frequency = float(_squeeze_ds(sr))

            # Remove noise units if necessary:
            if keep_good_only:
                units = [
                    unit for unit in units
                    if unit["ID"].flatten()[0].astype(int) % 1000 != 0
                ]

            if 'sortingInfo' in self._data.keys():
                info = self._data["sortingInfo"]
                start_frame = _squeeze_ds(info['startTimes'])
                self.start_frame = int(start_frame)
            else:
                self.start_frame = 0
        else:
            _units = self._getfield('Units').squeeze()
            fields = _units.dtype.fields.keys()
            units = []

            for unit in _units:
                unit_dict = {}
                for f in fields:
                    unit_dict[f] = unit[f]
                units.append(unit_dict)

            sr = self._getfield("samplingRate")
            sampling_frequency = float(_squeeze_ds(sr))

            _ME = self._data["MultiElectrode"]
            multi_electrode = dict(
                (k, _ME[k][0][0].T) for k in _ME.dtype.fields.keys())

            # Remove noise units if necessary:
            if keep_good_only:
                units = [
                    unit for unit in units
                    if unit["ID"].flatten()[0].astype(int) % 1000 != 0
                ]

            if 'sortingInfo' in self._data.keys():
                info = self._getfield("sortingInfo")
                start_frame = _squeeze_ds(info['startTimes'])
                self.start_frame = int(start_frame)
            else:
                self.start_frame = 0

        self._units = units
        self._multi_electrode = multi_electrode

        unit_ids = []
        spiketrains = []
        for uc, unit in enumerate(units):
            unit_id = int(_squeeze_ds(unit["ID"]))
            spike_times = _squeeze(
                unit["spikeTrain"]).astype('int64') - self.start_frame
            unit_ids.append(unit_id)
            spiketrains.append(spike_times)

        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        self.add_sorting_segment(HDSortSortingSegment(unit_ids, spiketrains))

        # property
        templates = []
        templates_frames_cut_before = []
        for uc, unit in enumerate(units):
            if self._old_style_mat:
                template = unit["footprint"].T
            else:
                template = unit["footprint"]
            templates.append(template)
            templates_frames_cut_before.append(unit["cutLeft"].flatten())
        self.set_property("template", np.array(templates))
        self.set_property("template_frames_cut_before",
                          np.array(templates_frames_cut_before))

        self._kwargs = {
            'file_path': str(file_path),
            'keep_good_only': keep_good_only
        }
Example #21
    def __init__(self,
                 folder_path,
                 exclude_cluster_groups=None,
                 keep_good_only=False):
        try:
            import pandas as pd
            HAVE_PD = True
        except ImportError:
            HAVE_PD = False
        assert HAVE_PD, self.installation_mesg

        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder / 'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        clust_id = np.unique(spike_clusters)
        unit_ids = list(clust_id)
        spike_times = spike_times.astype(int)
        params = read_python(str(phy_folder / 'params.py'))
        sampling_frequency = params['sample_rate']

        # try to load cluster info
        cluster_info_files = [
            p for p in phy_folder.iterdir()
            if p.suffix in ['.csv', '.tsv'] and "cluster_info" in p.name
        ]
        if len(cluster_info_files) == 1:
            cluster_info_file = cluster_info_files[0]
            if cluster_info_file.suffix == ".tsv":
                delimeter = "\t"
            else:
                delimeter = ","
            cluster_info = pd.read_csv(cluster_info_file, delimiter=delimeter)

        else:
            all_property_files = [
                p for p in phy_folder.iterdir()
                if p.suffix in ['.csv', '.tsv']
            ]

            cluster_info = None
            for file in all_property_files:
                if file.suffix == ".tsv":
                    delimeter = "\t"
                else:
                    delimeter = ","
                new_property = pd.read_csv(file, delimiter=delimeter)
                if cluster_info is None:
                    cluster_info = new_property
                else:
                    cluster_info = pd.merge(cluster_info,
                                            new_property,
                                            on='cluster_id')

            cluster_info["id"] = cluster_info["cluster_id"]
            del cluster_info["cluster_id"]

        if exclude_cluster_groups is not None:
            if isinstance(exclude_cluster_groups, str):
                cluster_info = cluster_info.query(
                    f"group != '{exclude_cluster_groups}'")
            elif isinstance(exclude_cluster_groups, list):
                if len(exclude_cluster_groups) > 0:
                    for exclude_group in exclude_cluster_groups:
                        cluster_info = cluster_info.query(
                            f"group != '{exclude_group}'")

        if keep_good_only and "KSLabel" in cluster_info.columns:
            cluster_info = cluster_info.query("KSLabel == 'good'")

        unit_ids = cluster_info["id"].values
        BaseSorting.__init__(self, sampling_frequency, unit_ids)

        for prop_name in cluster_info.columns:
            if prop_name in ['chan_grp', 'ch_group']:
                self.set_property(key="group", values=cluster_info[prop_name])
            elif prop_name != "group":
                self.set_property(key=prop_name,
                                  values=cluster_info[prop_name])

        self.add_sorting_segment(PhySortingSegment(spike_times,
                                                   spike_clusters))

        self._kwargs = {
            'folder_path': str(Path(folder_path).absolute()),
            'exclude_cluster_groups': exclude_cluster_groups
        }