Ejemplo n.º 1
0
    def plot_isi(self, catalogueconstructor, num_bins=100):
        """Draw inter-spike-interval histograms for up to four clusters.

        The spike trains are obtained from ``get_spiketrains``. For each
        cluster, the differences between consecutive spike times are
        histogrammed and drawn as a filled step outline in one cell of a
        2 x 2 grid. Clusters beyond the grid capacity are not drawn (a
        warning is printed); grid cells without a cluster stay blank.

        ``catalogueconstructor`` is forwarded to ``get_spiketrains`` and its
        ``colors`` mapping supplies the fill color per cluster label.
        ``num_bins`` is passed to ``np.histogram`` as ``bins``.

        Returns the ``FigureCanvas`` that wraps the figure.
        """

        # Tick length handed to get_spiketrains, truncated to an integer
        # number of microseconds.
        us_per_tick = int(1e6 / self.dataio.sample_rate)
        start_us = int(self.config['start_time'] * 1e6)
        spiketrains = get_spiketrains(catalogueconstructor, us_per_tick,
                                      start_us)

        grid_size = 2
        capacity = grid_size * grid_size
        num_clusters = len(spiketrains)
        if num_clusters > capacity:
            print("WARNING: Only {} out of {} available plots can be shown."
                  "".format(capacity, num_clusters))

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        # ``flat`` walks the grid of axes row by row.
        for ax in fig.subplots(grid_size, grid_size).flat:
            if spiketrains:
                cluster_label, spike_times = spiketrains.popitem()
                counts, bin_edges = np.histogram(np.diff(spike_times),
                                                 bins=num_bins)
                color = catalogueconstructor.colors[cluster_label]
                # Duplicate the inner bin edges and every count so that the
                # filled outline follows the histogram steps.
                outline_x = np.repeat(bin_edges, 2)[1:-1]
                outline_y = np.repeat(counts, 2)
                ax.fill_between(outline_x,
                                0,
                                outline_y,
                                facecolor=color,
                                edgecolor='k')
                # No cluster label is drawn here: with an extremely long
                # ISI the label causes the figure to be too large.
            else:
                ax.plot(0, 0)
            ax.axis('off')

        fig.subplots_adjust(wspace=0, hspace=0)
        return canvas
Ejemplo n.º 2
0
    def save_spiketimes(self):
        """Export the sorted spike times of all selected channels.

        Does nothing if ``check_finished_loading()`` is falsy. Otherwise the
        spike sorter is run for every electrode whose ``get()`` is truthy,
        and each cluster's spike times are written with ``savemat`` (key
        ``'timestamps'``) to
        ``<output_path>/spiketimes/ch<channel label>_<suffix>.mat``.

        Cluster labels 0..6 get the suffixes ``'A'``..``'G'``. Any other
        label falls back to its string form instead of raising ``KeyError``
        and aborting the export of the remaining channels.

        Only channels for which the sorter returned a result are listed in
        the final log message.
        """
        if not self.check_finished_loading():
            return

        path = os.path.join(self.output_path, 'spiketimes')
        # ``exist_ok`` avoids the race between a separate existence check
        # and the directory creation.
        os.makedirs(path, exist_ok=True)

        unit_suffixes = dict(enumerate('ABCDEFG'))

        saved_channels = []
        for channel_idx, electrode in self.electrodes.items():
            if not electrode.get():
                continue

            label = self.labels[channel_idx]

            cc = self.run_spikesorter_on_channel(channel_idx, label)
            if cc is None:
                # Nothing was sorted for this channel, so it must not be
                # reported as saved.
                continue

            saved_channels.append(label)

            # Read the sample rate only after the sorter has run, in case
            # that call (re)assigns ``self.dataio`` -- not verifiable from
            # this method alone.
            s_per_tick = 1 / self.dataio.sample_rate

            spiketrains = get_spiketrains(cc, s_per_tick,
                                          self.config['start_time'])

            for cell_label, spike_times in spiketrains.items():
                suffix = unit_suffixes.get(cell_label, str(cell_label))
                filename = 'ch{}_{}.mat'.format(label, suffix)
                filepath = os.path.join(path, filename)
                savemat(filepath, {'timestamps': spike_times})

        self.log("Spike times of channels {} saved to {}.".format(
            saved_channels, path))
Ejemplo n.º 3
0
    def plot_raster(self, catalogueconstructor):
        """Draw trigger-aligned raster plots for up to four clusters.

        Spike trains come from ``get_spiketrains`` and trigger times from
        ``get_trigger_times``. For every trigger, the spikes that fall
        between ``trigger - pre`` and ``trigger + post`` are shifted so the
        trigger sits at zero, and are drawn as one row of an event plot.
        ``pre`` and ``post`` are read from ``self.plot_settings`` and
        scaled by 1e6.

        If no triggers are available, evenly spaced pseudo-triggers
        covering ``self.config['duration']`` are used instead; at least one
        pseudo-trigger is always generated.

        ``catalogueconstructor`` is forwarded to ``get_spiketrains`` and its
        ``colors`` mapping supplies the color per cluster label.

        Returns the ``FigureCanvas`` that wraps the figure.
        """

        us_per_tick = int(1e6 / self.dataio.sample_rate)

        spiketrains = get_spiketrains(catalogueconstructor, us_per_tick,
                                      int(self.config['start_time'] * 1e6))

        trigger_times = get_trigger_times(self.filepath, self.trigger_filename)

        num_clusters = len(spiketrains)

        if len(trigger_times) == 0:
            duration = self.config['duration']
            # For durations below one second the square root truncates to
            # zero, which would leave no trigger to index further down.
            num_triggers = max(1, int(np.sqrt(duration)))
            # Explicit 64-bit integers: where NumPy's default integer is
            # 32 bits wide, the scaling to microseconds below would wrap
            # around for times beyond roughly 2147 s.
            trigger_times = np.linspace(0,
                                        duration,
                                        num_triggers,
                                        endpoint=False,
                                        dtype=np.int64)
            trigger_times *= int(1e6)  # Seconds to microseconds.

        pre = int(self.plot_settings['pre'].get() * 1e6)
        post = int(self.plot_settings['post'].get() * 1e6)

        n = 2
        if num_clusters > n * n:
            print("WARNING: Only {} out of {} available plots can be shown."
                  "".format(n * n, num_clusters))

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        axes = fig.subplots(n, n)
        # ``flat`` walks the grid of axes row by row.
        for ax in axes.flat:
            ax.axis('off')
            if not spiketrains:
                ax.plot(0, 0)
                continue
            cluster_label, cluster_trains = spiketrains.popitem()
            if len(cluster_trains) == 0:
                # A cluster without spikes still occupies its grid cell.
                continue

            # Narrow the train to the span covered by all trigger windows
            # before extracting the individual windows.
            spike_times_section = get_interval(cluster_trains,
                                               trigger_times[0] - pre,
                                               trigger_times[-1] + post)
            spike_times_zerocentered = []
            color = catalogueconstructor.colors[cluster_label]
            for trigger_time in trigger_times:
                x = get_interval(spike_times_section, trigger_time - pre,
                                 trigger_time + post)
                if len(x):
                    x -= trigger_time
                spike_times_zerocentered.append(x)
            ax.eventplot(spike_times_zerocentered,
                         color=color,
                         linewidths=0.5,
                         lineoffsets=-1)
            ax.text(0,
                    0.5,
                    cluster_label,
                    fontsize=28,
                    color=color,
                    transform=ax.transAxes)
            # Vertical marker at the trigger position, spanning the current
            # y-range of the event plot.
            ax.vlines(0, *ax.get_ylim(), linewidth=1, alpha=0.9)
            ax.set_xlim(-pre, post)

        fig.subplots_adjust(wspace=0, hspace=0)
        return canvas
Ejemplo n.º 4
0
    def plot_psth(self, catalogueconstructor):
        """Draw peri-stimulus time histograms for up to four clusters.

        Spike trains come from ``get_spiketrains`` and trigger times from
        ``get_trigger_times``. For each cluster, the spikes between
        ``trigger - pre`` and ``trigger + post`` of every trigger are
        shifted so the trigger sits at zero, pooled and histogrammed. All
        panels share one y-range given by the largest bin count found.

        ``catalogueconstructor`` is forwarded to ``get_spiketrains`` and its
        ``colors`` mapping supplies the color per cluster label.

        Returns the ``FigureCanvas`` that wraps the figure, or ``None``
        (after printing a notice) if there are no trigger times.
        """

        us_per_tick = int(1e6 / self.dataio.sample_rate)
        start_us = int(self.config['start_time'] * 1e6)
        spiketrains = get_spiketrains(catalogueconstructor, us_per_tick,
                                      start_us)

        trigger_times = get_trigger_times(self.filepath, self.trigger_filename)

        # Without triggers there is nothing to align the spikes to.
        if len(trigger_times) == 0:
            print("No trigger data available; aborting PSTH plot.")
            return

        pre = int(self.plot_settings['pre'].get() * 1e6)
        post = int(self.plot_settings['post'].get() * 1e6)
        bin_method = self.plot_settings['bin_method'].get()
        # Any setting other than 'manual' is handed to np.histogram as-is.
        if bin_method == 'manual':
            num_bins = self.plot_settings['num_bins'].get()
        else:
            num_bins = bin_method

        histograms = {}
        ylim = 0
        for cluster_label, cluster_trains in spiketrains.items():
            # Narrow the train to the span covered by all trigger windows.
            section = get_interval(cluster_trains,
                                   trigger_times[0] - pre,
                                   trigger_times[-1] + post)

            aligned_spikes = []
            for trigger_time in trigger_times:
                window = get_interval(section,
                                      trigger_time - pre,
                                      trigger_time + post)
                if len(window):
                    window -= trigger_time
                aligned_spikes.extend(window)

            counts, edges = np.histogram(aligned_spikes, num_bins)
            # Bin edges are stored divided by 1e6, mirroring the 1e6
            # scaling applied to ``pre`` and ``post`` above.
            histograms[cluster_label] = (edges / 1e6, counts)
            # Track the largest bin count as the common y-range.
            ylim = max(ylim, np.max(counts))

        grid_size = 2
        capacity = grid_size * grid_size
        if len(histograms) > capacity:
            print("WARNING: Only {} out of {} available plots can be shown."
                  "".format(capacity, len(histograms)))

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        # ``flat`` walks the grid of axes row by row.
        for ax in fig.subplots(grid_size, grid_size).flat:
            if histograms:
                cluster_label, (edges_s, counts) = histograms.popitem()
                color = catalogueconstructor.colors[cluster_label]
                ax.fill_between(edges_s[:-1],
                                0,
                                counts,
                                facecolor=color,
                                edgecolor='k')
                ax.text(0.7 * edges_s[-1],
                        0.7 * ylim,
                        cluster_label,
                        color=color,
                        fontsize=28)
                # Marker at the trigger position.
                ax.vlines(0, 0, ylim)
            else:
                ax.plot(0, 0)
            ax.axis('off')
            ax.set_ylim(0, ylim)
            ax.set_xlim(-pre / 1e6, post / 1e6)

        fig.subplots_adjust(wspace=0, hspace=0)
        return canvas
Ejemplo n.º 5
0
    def initialize_plot(self):
        """Fill ``self.canvas`` with PSTH plots for up to four clusters.

        Returns immediately if ``self.trigger_times`` is empty. Otherwise
        the spike trains from ``get_spiketrains`` are aligned to every
        trigger (window ``trigger - pre`` to ``trigger + post``), pooled and
        histogrammed per cluster. Each histogram is added to the canvas as a
        step plot in a 2 x 2 layout. All plots share one y-range given by
        the largest bin count found. Double-clicking a plot's view box is
        connected to ``self.open_settings``.
        """

        # Without triggers there is nothing to align the spikes to.
        if len(self.trigger_times) == 0:
            return

        constructor = self.catalogueconstructor
        us_per_tick = int(1e6 / constructor.dataio.sample_rate)

        start = int(constructor.cbnu.config['start_time'] * 1e6)
        spiketrains = get_spiketrains(self.catalogueconstructor, us_per_tick,
                                      start)

        pre = int(self.params['pre'] * 1e6)
        post = int(self.params['post'] * 1e6)
        bin_method = self.params['bin_method']
        # Any setting other than 'manual' is handed to np.histogram as-is.
        if bin_method == 'manual':
            num_bins = self.params['num_bins']
        else:
            num_bins = bin_method

        histograms = {}
        ylim = 0
        for cluster_label, cluster_trains in spiketrains.items():
            # Narrow the train to the span covered by all trigger windows.
            section = get_interval(cluster_trains,
                                   self.trigger_times[0] - pre,
                                   self.trigger_times[-1] + post)

            aligned_spikes = []
            for trigger_time in self.trigger_times:
                window = get_interval(section,
                                      trigger_time - pre,
                                      trigger_time + post)
                if len(window):
                    window -= trigger_time
                aligned_spikes.extend(window)

            counts, edges = np.histogram(aligned_spikes, num_bins)
            # Bin edges are stored divided by 1e6, mirroring the 1e6
            # scaling applied to ``pre`` and ``post`` above.
            histograms[cluster_label] = (edges / 1e6, counts)
            # Track the largest bin count as the common y-range.
            ylim = max(ylim, np.max(counts))

        grid_size = 2
        capacity = grid_size * grid_size
        if len(histograms) > capacity:
            print("WARNING: Only {} out of {} available PSTH plots can be "
                  "shown.".format(capacity, len(histograms)))

        view_boxes = []
        for cell in range(capacity):
            # Stop as soon as every histogram has been placed.
            if not histograms:
                return
            row, col = divmod(cell, grid_size)
            view_box = MyViewBox()
            view_boxes.append(view_box)
            plot_item = self.canvas.addPlot(row=row, col=col,
                                            viewBox=view_box)
            cluster_label, (edges_s, counts) = histograms.popitem()
            # Clusters without a registered color are drawn in white.
            color = self.controller.qcolors.get(cluster_label,
                                                QT.QColor('white'))
            plot_item.plot(edges_s, counts, stepMode=True, fillLevel=0,
                           brush=color)
            label_item = pg.TextItem(str(cluster_label), color)
            label_item.setPos(0, ylim)
            plot_item.addItem(label_item)
            plot_item.setYRange(0, ylim)
            plot_item.setXRange(-pre / 1e6, post / 1e6)

            view_box.doubleclicked.connect(self.open_settings)