Example #1
def scatter_amp_depth_fr_plot(spike_amps,
                              spike_clusters,
                              spike_depths,
                              spike_times,
                              cmap='hot',
                              display=False):
    """
    Prepare data for 2D scatter plot of cluster depth vs cluster amp with colour indicating cluster
    firing rate

    :param spike_amps: amplitude of each spike (V)
    :param spike_clusters: cluster id of each spike
    :param spike_depths: depth of each spike along the probe (um)
    :param spike_times: time of each spike (s)
    :param cmap: matplotlib colormap used to encode firing rate
    :param display: generate figure
    :return: ScatterPlot object; if display=True, also returns the matplotlib fig and ax objects
    """

    cluster, cluster_depth, n_cluster = compute_cluster_average(
        spike_clusters, spike_depths)
    _, cluster_amp, _ = compute_cluster_average(spike_clusters, spike_amps)
    cluster_amp = cluster_amp * 1e6
    cluster_fr = n_cluster / np.max(spike_times)

    data = ScatterPlot(x=cluster_amp, y=cluster_depth, c=cluster_fr, cmap=cmap)
    data.set_xlim((0.9 * np.min(cluster_amp), 1.1 * np.max(cluster_amp)))

    if display:
        fig, ax = plot_scatter(data.convert2dict())
        return data.convert2dict(), fig, ax

    return data
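
A minimal usage sketch with synthetic spike arrays; it assumes compute_cluster_average, ScatterPlot and plot_scatter are already imported from the same modules the function above relies on (imports are not shown in the example):

import numpy as np

# Synthetic spike data: 1000 spikes spread over 20 clusters (illustrative only)
rng = np.random.default_rng(0)
spike_clusters = rng.integers(0, 20, size=1000)
spike_amps = rng.uniform(50e-6, 300e-6, size=1000)     # volts; scaled to uV inside
spike_depths = rng.uniform(0, 3840, size=1000)         # um along the probe
spike_times = np.sort(rng.uniform(0, 600, size=1000))  # seconds

# Returns a ScatterPlot container; pass display=True to also get a matplotlib figure
data = scatter_amp_depth_fr_plot(spike_amps, spike_clusters, spike_depths, spike_times)
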
Example #2
    def prepare_data(self, spikes, clusters, trials):
        self.spikes = spikes
        self.clusters = clusters
        self.trials = trials

        # Cluster ids present in the spike data and their positional indices
        self.ids = np.unique(spikes.clusters)
        self.ids_idx = np.arange(len(self.ids))

        # KS2 labels and per-cluster colours: green for 'good', orange for everything else
        self.metrics = np.array(clusters.metrics.ks2_label[self.ids_idx])
        self.colours = np.array(clusters.metrics.ks2_label[self.ids_idx])
        self.colours[np.where(self.colours != 'good')[0]] = QtGui.QColor(
            '#fdc086')  # default for noise and nan
        self.colours[np.where(
            self.colours == 'good')[0]] = QtGui.QColor('#7fc97f')
        self.locations = clusters.locations

        # Per-cluster averages of depth and amplitude (amplitude scaled to uV),
        # falling back to the channel depth where the spike depths are NaN
        _, self.depths, self.nspikes = compute_cluster_average(
            spikes.clusters, spikes.depths)
        self.depths[np.where(np.isnan(
            self.depths))[0]] = self.clusters.depths[np.where(
                np.isnan(self.depths))[0]]  # use channel depth as default
        _, self.amps, _ = compute_cluster_average(spikes.clusters, spikes.amps)
        self.amps = self.amps * 1e6

        # Pre-computed sort orders: by id, by descending spike count, and good before mua
        self.sort_by_id = np.arange(len(self.ids))
        self.sort_by_nspikes = np.argsort(self.nspikes)[::-1]
        self.sort_by_good = np.append(
            np.where(self.metrics == 'good')[0],
            np.where(self.metrics == 'mua')[0])
        self.n_trials = len(trials['contrastRight'])
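
The sort_by_nspikes attribute stores index positions rather than the counts themselves; a tiny self-contained sketch of that argsort pattern (synthetic numbers, not taken from any dataset):

import numpy as np

nspikes = np.array([10, 250, 40])   # spikes per cluster (illustrative)
ids = np.array([1, 2, 100])

order = np.argsort(nspikes)[::-1]   # index order, most active cluster first
print(ids[order])                   # -> [  2 100   1]
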
Example #3
    def test_compute_cluster_average(self):
        # Create fake data for 3 clusters
        clust1 = np.ones(40)
        clust1_vals = np.ones(40) * 200
        clust2 = 2 * np.ones(40)
        clust2_vals = np.r_[np.ones(20) * 300, np.ones(20) * 500]
        clust100 = 100 * np.ones(50)
        clust100_vals = np.r_[np.ones(25) * 0.5, np.ones(25) * 1.0]

        # Concatenate data for 3 clusters together
        spike_clust = np.r_[clust1, clust2, clust100]
        spike_val = np.r_[clust1_vals, clust2_vals, clust100_vals]

        # Shuffle the data to make order random
        ind = np.arange(len(spike_clust))
        np.random.shuffle(ind)
        spike_clust = spike_clust[ind]
        spike_val = spike_val[ind]
        # Make sure the data you have created has the correct dimensions
        assert (len(spike_clust) == len(spike_val))

        # Compute the average value across clusters
        clust, avg_val, count = processing.compute_cluster_average(
            spike_clust, spike_val)

        # Check output is as expected
        assert (np.all(clust == (1, 2, 100)))
        assert (avg_val[0] == 200)
        assert (avg_val[1] == 400)
        assert (avg_val[2] == 0.75)
        assert (np.all(count == (40, 40, 50)))
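
The test pins down the contract of processing.compute_cluster_average: it returns the sorted unique cluster ids, the mean value per cluster, and the spike count per cluster. A minimal NumPy sketch that satisfies the same contract (an illustration, not the library's actual implementation):

import numpy as np

def compute_cluster_average_sketch(spike_clusters, spike_values):
    # Sorted unique cluster ids, an index mapping each spike to its cluster,
    # and the number of spikes per cluster
    clusters, inverse, counts = np.unique(spike_clusters,
                                          return_inverse=True,
                                          return_counts=True)
    # Per-cluster sums of the values, divided by the counts to give the means
    sums = np.bincount(inverse, weights=spike_values)
    return clusters, sums / counts, counts

Fed the shuffled spike_clust and spike_val arrays built in the test, this sketch returns clusters (1, 2, 100), averages (200, 400, 0.75) and counts (40, 40, 50).
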
Example #4
def scatter_amp_depth_fr_plot(spike_amps,
                              spike_clusters,
                              spike_depths,
                              spike_times,
                              cmap='hot',
                              display=False,
                              title=None,
                              **kwargs):
    """
    Prepare data for 2D scatter plot of cluster depth vs cluster amp with colour indicating cluster
    firing rate

    :param spike_amps: amplitude of each spike (V)
    :param spike_clusters: cluster id of each spike
    :param spike_depths: depth of each spike along the probe (um)
    :param spike_times: time of each spike (s)
    :param cmap: matplotlib colormap used to encode firing rate
    :param display: generate figure
    :param title: plot title (defaults to 'Cluster depth vs amp vs firing rate')
    :param kwargs: additional keyword arguments passed through to plot_scatter
    :return: ScatterPlot object; if display=True, also returns the matplotlib fig and ax objects
    """

    title = title or 'Cluster depth vs amp vs firing rate'
    cluster, cluster_depth, n_cluster = compute_cluster_average(
        spike_clusters, spike_depths)
    _, cluster_amp, _ = compute_cluster_average(spike_clusters, spike_amps)
    cluster_amp = cluster_amp * 1e6
    cluster_fr = n_cluster / np.max(spike_times)

    data = ScatterPlot(x=cluster_amp, y=cluster_depth, c=cluster_fr, cmap=cmap)
    data.set_xlim((0.9 * np.min(cluster_amp), 1.1 * np.max(cluster_amp)))
    data.set_labels(title=title,
                    xlabel='Cluster Amplitude (uV)',
                    ylabel='Distance from probe tip (um)',
                    clabel='Firing rate (Hz)')
    if display:
        ax, fig = plot_scatter(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data
Example #5
    def __init__(self, eid, probe, one=None, spike_collection=None):
        one = one or ONE()

        if spike_collection == '':
            collection = f'alf/{probe}'
        elif spike_collection:
            collection = f'alf/{probe}/{spike_collection}'
        else:
            # Pykilosort is the default; if not present, look for normal kilosort
            all_collections = one.list_collections(eid)

            if f'alf/{probe}/pykilosort' in all_collections:
                collection = f'alf/{probe}/pykilosort'
            else:
                collection = f'alf/{probe}'

        try:
            self.spikes = one.load_object(eid, obj='spikes', collection=collection,
                                          attribute=['clusters', 'times', 'amps', 'depths'])
            self.clusters = one.load_object(eid, obj='clusters', collection=collection,
                                            attribute=['metrics', 'waveforms'])

        except alf.exceptions.ALFObjectNotFound:
            logger.error(f'Could not load spike sorting for session: {eid} and probe: {probe}, GUI'
                         f' will not work')
            raise

        # Get everything we need for the clusters
        # need to get rid of nans in amps and depths
        self.clusters.clust_ids, self.clusters.depths, n_spikes = \
            compute_cluster_average(self.spikes.clusters[~np.isnan(self.spikes.depths)],
                                    self.spikes.depths[~np.isnan(self.spikes.depths)])
        _, self.clusters.amps, _ = compute_cluster_average(
            self.spikes.clusters[~np.isnan(self.spikes.amps)],
            self.spikes.amps[~np.isnan(self.spikes.amps)])

        # If we don't have a metrics file, we can't assign colours
        if not any(self.clusters.get('metrics', [None])):
            col = np.full((len(self.clusters.clust_ids)), colours['no metric'])
            self.clusters.colours_ks = col
            self.clusters.colours_ibl = col
            self.clusters['KS good'] = np.arange(len(self.clusters.clust_ids))
            self.clusters['IBL good'] = np.arange(len(self.clusters.clust_ids))
        elif not any(self.clusters.metrics.get('ks2_label', [None])):
            col = np.full((len(self.clusters.clust_ids)), colours['no metric'])
            self.clusters.colours_ks = col

            colours_ibl = np.array(self.clusters.metrics.ks2_label[self.clusters.clust_ids])
            good_ibl = np.where(self.clusters.metrics.label[self.clusters.clust_ids] == 1)[0]
            colours_ibl[good_ibl] = colours['IBL good']
            bad_ibl = np.where(self.clusters.metrics.label[self.clusters.clust_ids] != 1)[0]
            colours_ibl[bad_ibl] = colours['IBL bad']
            self.clusters.colours_ibl = colours_ibl
            self.clusters.metrics['amp_median'] *= 1e6

            self.clusters['KS good'] = np.arange(len(self.clusters.clust_ids))
            self.clusters['IBL good'] = np.r_[good_ibl, bad_ibl]
        else:
            # KS2 good/mua colours
            # Bug in the KS2 units: in some cases, clusters that do not exist in
            # spikes.clusters have not been filled with NaNs like the other features in
            # metrics. The signature of this is that the last values in ks2_label are None.
            if len(np.where(self.clusters.metrics.ks2_label.values == None)[0]) > 0: # noqa
                colours_ks = np.array(self.clusters.metrics.ks2_label
                                      [:len(self.clusters.clust_ids)])
            else:
                colours_ks = np.array(self.clusters.metrics.ks2_label[self.clusters.clust_ids])
            good_ks = np.where(colours_ks == 'good')[0]
            colours_ks[good_ks] = colours['KS good']
            mua_ks = np.where(colours_ks == 'mua')[0]
            colours_ks[mua_ks] = colours['KS mua']
            self.clusters.colours_ks = colours_ks

            # IBL good bad colours
            colours_ibl = np.array(self.clusters.metrics.ks2_label[self.clusters.clust_ids])
            good_ibl = np.where(self.clusters.metrics.label[self.clusters.clust_ids] == 1)[0]
            colours_ibl[good_ibl] = colours['IBL good']
            bad_ibl = np.where(self.clusters.metrics.label[self.clusters.clust_ids] != 1)[0]
            colours_ibl[bad_ibl] = colours['IBL bad']
            self.clusters.colours_ibl = colours_ibl
            self.clusters.metrics['amp_median'] *= 1e6

            self.clusters['KS good'] = np.r_[good_ks, mua_ks]
            self.clusters['IBL good'] = np.r_[good_ibl, bad_ibl]

        # Some extra info for sorting clusters
        self.clusters['ids'] = np.arange(len(self.clusters.clust_ids))
        self.clusters['n spikes'] = np.argsort(n_spikes)[::-1]

        # Get trial data
        self.trials = one.load_object(eid, obj='trials', collection='alf')
        self.n_trials = self.trials['probabilityLeft'].shape[0]
        self.trial_events = [key for key in self.trials.keys() if 'time' in key]

        # Get behaviour data
        wheel = one.load_object(eid, obj='wheel', collection='alf')
        dlc_left = one.load_object(eid, obj='leftCamera', collection='alf',
                                   attribute=['times', 'dlc'])
        dlc_right = one.load_object(eid, obj='rightCamera', collection='alf',
                                    attribute=['times', 'dlc'])

        (self.behav, self.behav_events,
         self.dlc_aligned) = self.combine_behaviour_data(wheel, dlc_left, dlc_right)

        self.sess_qc = self.get_qc_info(eid, one)

        self.spikes_raster = Bunch()
        self.spikes_raster_psth = Bunch()
        self.behav_raster = Bunch()
        self.behav_raster_psth = Bunch()
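
A usage sketch for the loader above; the class name SpikeSortLoader and the probe label are hypothetical, the import path for ONE is an assumption, and ONE() must be configured for a database that actually holds the session:

from one.api import ONE  # assumption: the ONE client lives in the one.api package

one = ONE()
eid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'       # placeholder session UUID
loader = SpikeSortLoader(eid, 'probe00', one=one)  # hypothetical class/probe names

print(len(loader.clusters.clust_ids), 'clusters loaded')
print(loader.n_trials, 'trials,', len(loader.trial_events), 'time-aligned trial events')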