Пример #1
0
def get_peaks(_spike_times, _trigger_times, path, _delay, _cell_name,
              save_plot, _threshold, _num_bins, _pre, _post):
    """Find the (up to) two highest peaks of a trigger-aligned PSTH.

    Spike times around each trigger (selected with ``get_interval`` between
    ``trigger - _pre`` and ``trigger + _post``) are zero-centered on that
    trigger, converted from seconds to milliseconds and pooled. The pooled
    times are drawn with ``sns.distplot`` (histogram, rug and KDE). Peaks are
    then searched in the y-data of the first line on the resulting axes
    (presumably the KDE curve -- TODO confirm against the seaborn version in
    use). The robust threshold is ``median + _threshold * MAD``.

    Parameters
    ----------
    _spike_times : array-like
        Spike times of one cell, in seconds.
    _trigger_times : sequence
        Trigger times in seconds. Must be non-empty (``[0]`` and ``[-1]``
        are indexed).
    path : str
        Directory for the saved plot when ``save_plot`` is true.
    _delay, _cell_name
        Only used to build the output filename ``PSTH_<cell>_<delay>.png``.
    save_plot : bool
        If true, annotate and save the plot, then clear the current figure.
    _threshold : float
        Number of MADs above the median that a peak must reach.
    _num_bins : int
        Number of histogram bins passed to ``sns.distplot``.
    _pre, _post : float
        Window before / after each trigger, in seconds.

    Returns
    -------
    list
        Times (ms, relative to the trigger) of at most two peaks. When two
        peaks are found they are ordered by ascending height, i.e. the
        highest peak is LAST.
    """
    # Restrict to the span covered by all trigger windows first.
    spike_times_section = get_interval(_spike_times, _trigger_times[0] - _pre,
                                       _trigger_times[-1] + _post)

    spike_times_zerocentered = []
    for trigger_time in _trigger_times:
        x = get_interval(spike_times_section, trigger_time - _pre,
                         trigger_time + _post)
        if len(x):
            # Out-of-place on purpose: if get_interval returns a view, an
            # in-place update would also shift the caller's spike times (and
            # any overlapping later window). It also keeps integer inputs
            # from failing on the float multiplication.
            x = (x - trigger_time) * 1e3  # Zero-center; seconds to ms.
        spike_times_zerocentered.append(x)

    sns_fig = sns.distplot(np.concatenate(spike_times_zerocentered),
                           _num_bins, hist=True, rug=True, kde=True,
                           hist_kws={'align': 'left'})
    # x / y data of the first Line2D on the axes (presumably the KDE curve).
    line_x, line_y = sns_fig.get_lines()[0].get_data()

    sns_fig.set_xlabel("Time [ms]")

    # Robust threshold: median + k * median absolute deviation.
    # (Alternative, unused: mean + _threshold * std.)
    median = np.median(line_y)
    mad = np.median(np.abs(line_y - median))
    min_height = median + _threshold * mad
    peak_idxs, _ = find_peaks(line_y, min_height)
    # Indices of the (up to) two highest peaks, in ascending height order.
    max_peak_idxs = peak_idxs[np.argsort(line_y[peak_idxs])][-2:]

    ymax = 0.1  # Fixed height of the vertical marker lines.
    peak_times = []
    # Lower of the two peaks is marked green, the highest blue.
    for peak_idx, color in zip(max_peak_idxs, ('g', 'b')):
        peak_time = line_x[peak_idx]
        sns_fig.vlines(peak_time, 0, ymax, color=color)
        peak_times.append(peak_time)

    if save_plot:
        pre_ms = 1e3 * _pre
        post_ms = 1e3 * _post
        filepath = os.path.join(path,
                                'PSTH_{}_{}.png'.format(_cell_name, _delay))
        sns_fig.set_xlim(-pre_ms, post_ms)
        sns_fig.vlines(0, 0, ymax, color='r')  # Trigger onset.
        sns_fig.hlines(min_height, -pre_ms, post_ms, color='y')  # Threshold.
        sns_fig.get_figure().savefig(filepath)
        plt.clf()
    # NOTE(review): when save_plot is false the current figure is not
    # cleared, so a later call may draw onto the same axes -- confirm this
    # is intended.

    return peak_times
Пример #2
0
def get_spiketimes_zerocentered(_spike_times, _trigger_times):
    """Pool trigger-aligned spike times, converted to milliseconds.

    For every trigger, the spikes selected by ``get_interval`` between
    ``trigger - PRE`` and ``trigger + POST`` (module-level constants, in the
    same units as the spike times -- presumably seconds) are shifted so that
    the trigger sits at 0, then multiplied by 1e3 (seconds to ms).

    Parameters
    ----------
    _spike_times : array-like
        Spike times of one cell.
    _trigger_times : sequence
        Trigger times. Must be non-empty (``[0]`` and ``[-1]`` are indexed).

    Returns
    -------
    numpy.ndarray
        Concatenation of all per-trigger windows, in trigger order.
    """
    # Restrict to the span covered by all trigger windows first.
    spike_times_section = get_interval(_spike_times, _trigger_times[0] - PRE,
                                       _trigger_times[-1] + POST)

    spike_times_zerocentered = []
    for trigger_time in _trigger_times:
        x = get_interval(spike_times_section, trigger_time - PRE,
                         trigger_time + POST)
        if len(x):
            # Out-of-place on purpose: if get_interval returns a view, an
            # in-place update would also modify spike_times_section and the
            # caller's array. It also keeps integer inputs from failing on
            # the float multiplication.
            x = (x - trigger_time) * 1e3  # Zero-center; seconds to ms.
        spike_times_zerocentered.append(x)

    return np.concatenate(spike_times_zerocentered)
Пример #3
0
def run_single(path, save_plots, _threshold, _num_bins):
    """Detect one PSTH peak per trigger index, pooling spikes over all cells.

    Reads the module-level names ``spike_times_list``, ``trigger_times_list``,
    ``num_trials``, ``pre``, ``post`` (seconds) and
    ``relative_trigger_times``.

    For trigger index ``t``, the spikes of every cell of every recording near
    that recording's ``t``-th trigger are zero-centered, converted from
    seconds to ms and pooled into one population histogram.

    Parameters
    ----------
    path : str
        Output directory for ``PSTH_<trigger_idx>.png`` plots.
    save_plots : bool
        Whether to save an annotated histogram for each trigger index that
        has a peak.
    _threshold : float
        Peaks must exceed ``mean + _threshold * std`` of the bin counts.
    _num_bins : int
        Number of histogram bins.

    Returns
    -------
    list
        One entry per trigger index. The entry is ``-1`` if no peak passed
        the threshold. Otherwise it is the left bin edge of the highest
        peak, converted from ms to s, plus
        ``relative_trigger_times[trigger_idx]``.
    """
    # Pool zero-centered spike times (ms) per trigger index across all
    # recordings and cells.
    spike_times_population = [[] for _ in range(num_trials)]
    for cell_spike_dict, trigger_times in zip(spike_times_list,
                                              trigger_times_list):
        for cell_spikes in cell_spike_dict.values():
            for t, trigger_time in enumerate(trigger_times):
                x = get_interval(cell_spikes, trigger_time - pre,
                                 trigger_time + post)
                if len(x):
                    # Out-of-place on purpose: if get_interval returns a
                    # view, in-place arithmetic would corrupt the
                    # module-level spike data for later calls.
                    x = (x - trigger_time) * 1e3  # Zero-center; s to ms.
                spike_times_population[t].extend(x)

    _peaks = []
    for trigger_idx, population_spikes in enumerate(spike_times_population):
        figure = Figure()
        canvas = FigureCanvas(figure)
        axes = figure.subplots(1, 1)
        axes.set_xlabel("Time [ms]")
        counts, bin_edges, _ = axes.hist(population_spikes,
                                         _num_bins,
                                         align='left',
                                         histtype='stepfilled',
                                         facecolor='k')

        # Threshold: mean + k * std of the bin counts.
        # (Alternative, unused: median + k * MAD.)
        mean = np.mean(counts)
        std = np.std(counts)
        min_height = mean + _threshold * std
        peak_idxs, _ = find_peaks(counts, min_height)

        if len(peak_idxs) == 0:
            _peaks.append(-1)  # Sentinel: no peak above threshold.
            continue

        max_peak_idx = peak_idxs[np.argmax(counts[peak_idxs])]
        peak_time = bin_edges[max_peak_idx]  # Left edge of peak bin, in ms.
        # Convert peak_time from ms to s and add this trigger's offset.
        _peaks.append(peak_time / 1e3 + relative_trigger_times[trigger_idx])

        if save_plots:
            ymax = axes.get_ylim()[1]
            axes.vlines(peak_time, 0, ymax, color='g')  # Detected peak.
            pre_ms = 1e3 * pre
            post_ms = 1e3 * post
            axes.set_xlim(-pre_ms, post_ms)
            axes.vlines(0, 0, ymax, color='r')  # Trigger onset.
            axes.hlines(min_height, -pre_ms, post_ms, color='y')  # Threshold.
            figure.subplots_adjust(wspace=0, hspace=0)
            filepath = os.path.join(path, 'PSTH_{}.png'.format(trigger_idx))
            canvas.print_figure(filepath, bbox_inches='tight', dpi=100)

    return _peaks
Пример #4
0
    def plot_raster(self, catalogueconstructor):
        """Draw trigger-aligned spike rasters for up to ``n * n`` clusters.

        Times are handled in microseconds: config times are multiplied by
        1e6, and spike trains come from ``get_spiketrains`` (presumably in
        microseconds as well -- TODO confirm). If ``get_trigger_times``
        returns no triggers, about ``sqrt(duration)`` evenly spaced
        pseudo-triggers (at least one) are used instead.

        Parameters
        ----------
        catalogueconstructor
            Source of the spike trains and of the per-cluster ``colors``.

        Returns
        -------
        FigureCanvas
            Canvas holding an n x n grid (n = 2) of raster plots with axes
            hidden. Clusters beyond the grid are not shown; a warning is
            printed.
        """
        us_per_tick = int(1e6 / self.dataio.sample_rate)

        spiketrains = get_spiketrains(catalogueconstructor, us_per_tick,
                                      int(self.config['start_time'] * 1e6))

        trigger_times = get_trigger_times(self.filepath, self.trigger_filename)

        num_clusters = len(spiketrains)

        if len(trigger_times) == 0:
            # No recorded triggers: fall back to evenly spaced pseudo-triggers.
            duration = self.config['duration']
            # At least one, otherwise trigger_times[0] below raises
            # IndexError for recordings shorter than one second.
            num_triggers = max(1, int(np.sqrt(duration)))
            trigger_times = np.linspace(0,
                                        duration,
                                        num_triggers,
                                        endpoint=False,
                                        dtype=int)
            trigger_times *= int(1e6)  # Seconds to microseconds.

        pre = int(self.plot_settings['pre'].get() * 1e6)
        post = int(self.plot_settings['post'].get() * 1e6)

        n = 2  # The grid is n x n.
        if num_clusters > n * n:
            print("WARNING: Only {} out of {} available plots can be shown."
                  "".format(n * n, num_clusters))

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        # Size the grid from n so it always matches the loops below.
        axes = fig.subplots(n, n, squeeze=False)
        for i in range(n):
            for j in range(n):
                ax = axes[i, j]
                ax.axis('off')
                if len(spiketrains) == 0:
                    ax.plot(0, 0)  # Placeholder for an empty grid cell.
                    continue
                cluster_label, cluster_trains = spiketrains.popitem()
                if len(cluster_trains) == 0:
                    continue

                spike_times_section = get_interval(cluster_trains,
                                                   trigger_times[0] - pre,
                                                   trigger_times[-1] + post)
                color = catalogueconstructor.colors[cluster_label]
                spike_times_zerocentered = []
                for trigger_time in trigger_times:
                    x = get_interval(spike_times_section, trigger_time - pre,
                                     trigger_time + post)
                    if len(x):
                        # Out-of-place so a view returned by get_interval
                        # never shifts spike_times_section itself.
                        x = x - trigger_time
                    spike_times_zerocentered.append(x)
                # One raster row per trigger. With 2-D input, matplotlib
                # treats a scalar lineoffsets as the row spacing, so rows
                # stack downward.
                ax.eventplot(spike_times_zerocentered,
                             color=color,
                             linewidths=0.5,
                             lineoffsets=-1)
                ax.text(0,
                        0.5,
                        cluster_label,
                        fontsize=28,
                        color=color,
                        transform=ax.transAxes)
                # Vertical line at the trigger onset.
                ax.vlines(0,
                          *ax.get_ylim(),
                          linewidth=1,
                          alpha=0.9)
                ax.set_xlim(-pre, post)

        fig.subplots_adjust(wspace=0, hspace=0)
        return canvas
Пример #5
0
    def plot_psth(self, catalogueconstructor):
        """Plot PSTH of spiketrains.

        For every cluster, spike times near each trigger are zero-centered
        on that trigger (times handled in microseconds) and pooled into a
        histogram. The bin count comes from the ``num_bins`` setting when
        ``bin_method == 'manual'``. Otherwise ``bin_method`` itself is passed
        to ``np.histogram`` (presumably a numpy bin-estimator name such as
        ``'auto'``). All plots share the same y range (largest bin count).

        Parameters
        ----------
        catalogueconstructor
            Source of the spike trains and of the per-cluster ``colors``.

        Returns
        -------
        FigureCanvas or None
            Canvas holding an n x n grid (n = 2) of PSTHs with axes hidden.
            Returns None (after printing a message) if there are no triggers.
        """
        us_per_tick = int(1e6 / self.dataio.sample_rate)

        spiketrains = get_spiketrains(catalogueconstructor, us_per_tick,
                                      int(self.config['start_time'] * 1e6))

        trigger_times = get_trigger_times(self.filepath, self.trigger_filename)

        # Return if there are no triggers.
        if len(trigger_times) == 0:
            print("No trigger data available; aborting PSTH plot.")
            return

        pre = int(self.plot_settings['pre'].get() * 1e6)
        post = int(self.plot_settings['post'].get() * 1e6)
        bin_method = self.plot_settings['bin_method'].get()
        num_bins = (self.plot_settings['num_bins'].get()
                    if bin_method == 'manual' else bin_method)

        histograms = {}
        ylim = 0  # Shared y range: largest bin count over all clusters.
        for cluster_label, cluster_trains in spiketrains.items():
            spike_times_section = get_interval(cluster_trains,
                                               trigger_times[0] - pre,
                                               trigger_times[-1] + post)

            spike_times_zerocentered = []
            for trigger_time in trigger_times:
                x = get_interval(spike_times_section, trigger_time - pre,
                                 trigger_time + post)
                if len(x):
                    # Out-of-place so a view returned by get_interval never
                    # shifts spike_times_section itself.
                    x = x - trigger_time
                spike_times_zerocentered.extend(x)

            cluster_counts, bin_edges = np.histogram(spike_times_zerocentered,
                                                     num_bins)
            # Bin edges are stored in seconds for plotting.
            histograms[cluster_label] = (bin_edges / 1e6, cluster_counts)
            ylim = max(ylim, np.max(cluster_counts))

        n = 2  # The grid is n x n.
        if len(histograms) > n * n:
            print("WARNING: Only {} out of {} available plots can be shown."
                  "".format(n * n, len(histograms)))

        fig = Figure(figsize=(5, 5))
        canvas = FigureCanvas(fig)
        # Size the grid from n so it always matches the loops below.
        axes = fig.subplots(n, n, squeeze=False)
        for i in range(n):
            for j in range(n):
                ax = axes[i, j]
                if len(histograms) == 0:
                    ax.plot(0, 0)  # Placeholder for an empty grid cell.
                else:
                    cluster_label, (x, y) = histograms.popitem()
                    color = catalogueconstructor.colors[cluster_label]
                    # Each count is drawn from its bin's left edge.
                    ax.fill_between(x[:-1],
                                    0,
                                    y,
                                    facecolor=color,
                                    edgecolor='k')
                    ax.text(0.7 * x[-1],
                            0.7 * ylim,
                            cluster_label,
                            color=color,
                            fontsize=28)
                    ax.vlines(0, 0, ylim)  # Trigger onset.
                ax.axis('off')
                ax.set_ylim(0, ylim)
                ax.set_xlim(-pre / 1e6, post / 1e6)

        fig.subplots_adjust(wspace=0, hspace=0)
        return canvas
Пример #6
0
    def load(self):
        """Load electrode traces and trigger data from the MCS .h5 files.

        For every file in ``self.filenames``:

        1. Read the electrode data. This is the first analog stream whose
           ``data_subtype`` is ``'Electrode'``.
        2. Cut it to the configured ``start_time`` .. ``stop_time`` window
           (config values in seconds). The end is clamped to the available
           samples.
        3. Append the result to ``self.array_sources``.
        4. Collect trigger data, in order of preference, from:
           - digital event entities;
           - the other analog stream, if present;
           - an empty placeholder.
        5. Write each trigger set next to the input file as
           ``<name>_stimulus<i>.npz``, plus ``<name>_trigger_times<i>.txt``
           and ``<name>_trigger_data<i>.txt``.

        Finally, check that all files agree in sample rate, dtype, channel
        count and channel names, and store those on ``self``.

        Raises
        ------
        AssertionError
            If a file has more than one recording or no electrode stream, or
            if the files disagree in sample rate, dtype, channel count or
            channel names.
        """
        from McsPy.McsData import RawData

        # Config times are in seconds; work in microseconds.
        start_time = int(self.gui.config['start_time'] * 1e6)
        stop_time = int(self.gui.config['stop_time'] * 1e6)

        sample_rates = []
        dtypes = []
        num_channels = []
        channel_names = []
        for filename in self.filenames:
            self.log("Loading .h5 file: {}".format(filename))
            data = RawData(filename)
            assert len(data.recordings) == 1, \
                "Can only handle a single recording per file."

            electrode_data = None
            stream_id = None
            analog_streams = data.recordings[0].analog_streams
            for stream_id, stream in analog_streams.items():
                if stream.data_subtype == 'Electrode':
                    electrode_data = stream
                    break
            assert electrode_data is not None, "Electrode data not found."

            traces, sample_rate = get_all_channel_data(electrode_data)
            us_per_tick = int(1e6 / sample_rate)
            start_tick = start_time // us_per_tick
            full_duration = len(traces)
            # Clamp the end of the window to the available samples.
            stop_tick = min(stop_time // us_per_tick, full_duration)
            self.array_sources.append(traces[start_tick:stop_tick])
            sample_rates.append(sample_rate)
            dtypes.append(traces.dtype)
            num_channels.append(traces.shape[1])
            channel_names.append(get_channel_names(electrode_data))

            trigger_data = []
            trigger_times = []
            event_streams = data.recordings[0].event_streams
            # First, try loading trigger data from digital event stream.
            if event_streams is not None:
                for event_stream in event_streams.values():
                    for d in event_stream.event_entity.values():
                        if d.info.label == 'Digital Event Detector Event' or \
                                'Single Pulse Start' in d.info.label:
                            tr_times = d.data[0]
                            trigger_ticks = tr_times // us_per_tick
                            # Binary trace over the whole recording with a 1
                            # at every trigger sample, then windowed.
                            tr_data = np.zeros(full_duration)
                            tr_data[trigger_ticks] = 1
                            trigger_data.append(tr_data[start_tick:stop_tick])
                            trigger_times.append(
                                get_interval(tr_times, start_time, stop_time))

            # If triggers not stored as digital events, try analog stream.
            # Assumes electrode and trigger streams use ids 0 and 1 -- TODO
            # confirm for other recording layouts.
            if len(trigger_times) == 0:
                analog_stream_id = (stream_id + 1) % 2
                # Guarded so recordings without a trigger stream (e.g.
                # spontaneous activity) reach the placeholder below instead
                # of raising KeyError.
                if analog_stream_id in analog_streams:
                    tr_data = analog_streams[analog_stream_id].channel_data[0]
                    # Rising edges larger than |min| of the signal count as
                    # triggers.
                    trigger_ticks = np.flatnonzero(
                        np.diff(tr_data) > np.abs(np.min(tr_data)))
                    tr_times = trigger_ticks * us_per_tick
                    trigger_times.append(
                        get_interval(tr_times, start_time, stop_time))
                    trigger_data.append(tr_data[start_tick:stop_tick])

            # If no triggers are available (e.g. spontaneous activity), create
            # null array.
            if len(trigger_times) == 0:
                trigger_times.append(np.array([]))
                trigger_data.append(np.zeros(stop_tick - start_tick))

            # Save stimulus as compressed numpy file for later use in GUI
            # (numpy adds the .npz suffix), plus text copies for easier
            # access in matlab.
            dirname, basename = os.path.split(filename)
            basename, _ = os.path.splitext(basename)
            for i, (tr_times_i, tr_data_i) in enumerate(
                    zip(trigger_times, trigger_data)):
                np.savez_compressed(
                    os.path.join(dirname,
                                 '{}_stimulus{}'.format(basename, i)),
                    times=tr_times_i,
                    data=tr_data_i)
                np.savetxt(
                    os.path.join(dirname,
                                 '{}_trigger_times{}.txt'.format(basename, i)),
                    tr_times_i,
                    fmt='%d')
                np.savetxt(
                    os.path.join(dirname,
                                 '{}_trigger_data{}.txt'.format(basename, i)),
                    tr_data_i,
                    fmt='%d')

        # Make sure that every file uses the same sample rate, dtype, etc.
        assert np.array_equiv(sample_rates, sample_rates[0]), \
            "Recording contains different sample rates."

        assert np.array_equiv(dtypes, dtypes[0]), \
            "Recording contains different dtypes."

        assert np.array_equiv(num_channels, num_channels[0]), \
            "Recording contains different number of channels."

        assert np.array_equiv(channel_names, channel_names[0]), \
            "Recording contains different channel names."

        self.total_channel = num_channels[0]
        self.sample_rate = sample_rates[0]
        self.dtype = dtypes[0]
        self.channel_names = channel_names[0]

        self.log("Finished initializing DataSource.")
Пример #7
0
    def initialize_plot(self):
        """Fill ``self.canvas`` with a grid of per-cluster PSTH plots.

        Does nothing if ``self.trigger_times`` is empty. Otherwise, spike
        times near every trigger are zero-centered and pooled into one
        histogram per cluster (times handled in microseconds, displayed in
        seconds). The window is ``pre`` / ``post`` from ``self.params`` (in
        seconds). The bin count is ``num_bins`` when
        ``bin_method == 'manual'``; otherwise ``bin_method`` itself is passed
        to ``np.histogram``.

        At most n x n (n = 2) plots are added. All plots share the same y
        range (largest bin count). Double-clicking a plot's view box is
        connected to ``self.open_settings``.
        """
        # Return if there are no triggers.
        if len(self.trigger_times) == 0:
            return

        us_per_tick = int(1e6 / self.catalogueconstructor.dataio.sample_rate)

        start = int(self.catalogueconstructor.cbnu.config['start_time'] * 1e6)
        spiketrains = get_spiketrains(self.catalogueconstructor, us_per_tick,
                                      start)

        pre = int(self.params['pre'] * 1e6)
        post = int(self.params['post'] * 1e6)
        bin_method = self.params['bin_method']
        num_bins = (self.params['num_bins'] if bin_method == 'manual'
                    else bin_method)

        histograms = {}
        ylim = 0  # Shared y range: largest bin count over all clusters.
        for cluster_label, cluster_trains in spiketrains.items():
            spike_times_section = get_interval(cluster_trains,
                                               self.trigger_times[0] - pre,
                                               self.trigger_times[-1] + post)

            spike_times_zerocentered = []
            for trigger_time in self.trigger_times:
                x = get_interval(spike_times_section, trigger_time - pre,
                                 trigger_time + post)
                if len(x):
                    # Out-of-place so a view returned by get_interval never
                    # shifts spike_times_section itself.
                    x = x - trigger_time
                spike_times_zerocentered.extend(x)

            cluster_counts, bin_edges = np.histogram(spike_times_zerocentered,
                                                     num_bins)
            # Bin edges are stored in seconds for plotting.
            histograms[cluster_label] = (bin_edges / 1e6, cluster_counts)
            ylim = max(ylim, np.max(cluster_counts))

        n = 2  # The grid is n x n.
        if len(histograms) > n * n:
            print("WARNING: Only {} out of {} available PSTH plots can be "
                  "shown.".format(n * n, len(histograms)))

        viewboxes = []
        for i in range(n):
            for j in range(n):
                if len(histograms) == 0:
                    return  # No clusters left to place.
                viewboxes.append(MyViewBox())
                # Named plot_item (not plt) to avoid shadowing the common
                # matplotlib.pyplot alias.
                plot_item = self.canvas.addPlot(row=i, col=j,
                                                viewBox=viewboxes[-1])
                cluster_label, (x, y) = histograms.popitem()
                color = self.controller.qcolors.get(cluster_label,
                                                    QT.QColor('white'))
                # x holds the bin edges (one more than the counts in y).
                plot_item.plot(x, y, stepMode=True, fillLevel=0, brush=color)
                txt = pg.TextItem(str(cluster_label), color)
                txt.setPos(0, ylim)
                plot_item.addItem(txt)
                plot_item.setYRange(0, ylim)
                plot_item.setXRange(-pre / 1e6, post / 1e6)

                viewboxes[-1].doubleclicked.connect(self.open_settings)