Пример #1
0
def raster(trains, time_unit=pq.ms, show_lines=True, events=None, epochs=None):
    """ Open a plot window showing a raster plot of spike trains.

    Each spike train gets its own row. Rows are numbered from
    ``len(trains)`` (first train in iteration order) down to 1.

    :param dict trains: Dictionary of spike trains indexed by a
        Neo object (Unit or Segment). The key's ``name`` is used as the
        legend entry; keys without a name get no legend entry.
    :param Quantity time_unit: Unit of X-Axis. Milliseconds are used if
        a false value is passed.
    :param bool show_lines: Determines if a horizontal line from
        ``t_start`` to ``t_stop`` is drawn for each spike train.
    :param sequence events: A sequence of neo `Event` objects that will
        be marked on the plot.
    :param sequence epochs: A sequence of neo `Epoch` objects that will
        be marked on the plot.
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If ``trains`` is empty.
    """
    if not trains:
        raise SpykeException('No spike trains for rasterplot')

    time_unit = time_unit or pq.ms

    win = PlotDialog(toolbar=True, wintitle='Spike Trains', major_grid=False)
    widget = BaseCurveWidget(win)
    plot = widget.plot

    events = events if events is not None else []
    epochs = epochs if epochs is not None else []

    train_count = len(trains)
    named_items = []
    for row, (obj, train) in enumerate(trains.iteritems()):
        # Row height: the first train is placed highest
        y = train_count - row
        color = helper.get_object_color(obj)
        spike_item = helper.add_spikes(plot, train, color, 2, 21, y,
                                       obj.name, time_unit)
        if obj.name:
            named_items.append(spike_item)

        if show_lines:
            span = [train.t_start.rescale(time_unit),
                    train.t_stop.rescale(time_unit)]
            plot.add_item(make.curve(span, [y, y], color='k'))

    helper.add_epochs(plot, epochs, time_unit)
    helper.add_events(plot, events, time_unit)

    plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
    plot.set_axis_unit(BasePlot.X_BOTTOM, time_unit.dimensionality.string)

    win.add_plot_widget(widget, 0)

    legend = make.legend(restrict_items=named_items)
    plot.add_item(legend)
    win.add_legend_option([legend], True)

    if train_count > 1:
        plot.set_axis_limits(BasePlot.Y_LEFT, 0.5, train_count + 0.5)

    win.add_custom_curve_tools()
    win.show()

    return win
Пример #2
0
def spikes(spikes,
           axes_style,
           strong=None,
           anti_alias=False,
           fade=1.0,
           subplot_layout=0,
           time_unit=pq.ms,
           progress=None):
    """ Create a plot dialog with spike waveforms. Assumes that all spikes
    have waveforms with the same number of channels.

    :param dict spikes: A dictionary of :class:`neo.core.Spike` lists.
        May be ``None`` or empty if ``strong`` contains spikes.
    :param int axes_style: Plotting mode. The following values are possible:

        - 1: Show each channel in a separate plot, split vertically.
        - 2: Show each channel in a separate plot, split horizontally.
        - 3: Show each key of ``spikes`` or ``strong`` in a separate plot,
          channels are split vertically.
        - 4: Show each key of ``spikes`` or ``strong`` in a separate plot,
          channels are split horizontally.
        - 5: Show all channels in the same plot, split vertically.
        - 6: Show all channels in the same plot, split horizontally.

    :param dict strong: A dictionary of :class:`neo.core.Spike` lists. When
        given, these spikes are shown as thick lines on top of the regular
        spikes in the respective plots.
    :param bool anti_alias: Determines whether an antialiased plot is created.
    :param float fade: Vary transparency by segment. For values > 0, the first
        spike for each unit is displayed with the corresponding alpha
        value and alpha is linearly interpolated until it is 1 for the
        last spike. For values < 0, alpha is 1 for the first spike and
        ``fade`` for the last spike. Does not affect spikes from ``strong``.
    :param int subplot_layout: The way subplots are arranged on the window:

        - 0: Linear - horizontally or vertically,
          depending on ``axes_style``.
        - 1: Square - this layout tries to have the same number of plots per
          row and per column.

    :param Quantity time_unit: Unit of X-Axis.
    :param progress: Set this parameter to report progress.
    :type progress: :class:`spykeutils.progress_indicator.ProgressIndicator`
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If both ``spikes`` and ``strong`` contain no
        spikes, or if a spike checked here has no waveform or sampling rate.
    """
    # ``spikes`` may legitimately be None when only ``strong`` is given.
    if spikes is None:
        spikes = {}
    if strong is None:
        strong = {}

    spike_count = sum(len(l) for l in spikes.itervalues())
    strong_count = sum(len(l) for l in strong.itervalues())
    if spike_count + strong_count < 1:
        raise SpykeException('No spikes for spike waveform plot!')
    if not progress:
        progress = ProgressIndicator()

    progress.begin('Creating waveform plot')
    progress.set_ticks(spike_count + strong_count)
    win_title = 'Spike waveforms'
    win = PlotDialog(toolbar=True, wintitle=win_title)

    def check_waveform(s):
        """ Raise if spike ``s`` cannot be drawn as a waveform. """
        if s.waveform is None or s.sampling_rate is None:
            raise SpykeException('Cannot create waveform plot: '
                                 'At least one spike has no '
                                 'waveform or sampling rate!')

    # The first available spike is the reference for units and channel
    # count. The emptiness check above guarantees that one exists, even
    # when the first list in ``spikes`` is empty.
    ref_spike = next(spike_list[0]
                     for d in (spikes, strong)
                     for spike_list in d.itervalues()
                     if len(spike_list) > 0)
    if ref_spike.waveform is None:
        raise SpykeException('Cannot create waveform plot: At least one spike '
                             'has no waveform or sampling rate!')
    ref_units = ref_spike.waveform.units
    channels = range(ref_spike.waveform.shape[1])

    # Keys from spikes and strong without duplicates in original order
    seen = set()
    indices = [
        k for k in spikes.keys() + strong.keys()
        if k not in seen and not seen.add(k)
    ]

    if axes_style <= 2:  # Separate channel plots
        for c in channels:
            pw = BaseCurveWidget(win)
            plot = pw.plot
            plot.set_antialiasing(anti_alias)
            for u in spikes:
                color = helper.get_object_color(u)
                qcol = Qt.QColor(color)
                n_spikes = len(spikes[u])
                if n_spikes == 0:
                    # Nothing to draw; skipping also avoids a division by
                    # zero when computing the alpha step below.
                    continue
                alpha = fade if fade > 0.0 else 1.0
                alpha_step = 1.0 - fade if fade > 0.0 else -1.0 - fade
                alpha_step /= n_spikes
                if n_spikes == 1:
                    alpha = 1.0

                for s in spikes[u]:
                    check_waveform(s)
                    x = (sp.arange(s.waveform.shape[0]) /
                         s.sampling_rate).rescale(time_unit)
                    curve = make.curve(x,
                                       s.waveform[:, c].rescale(ref_units),
                                       title=u.name,
                                       color=color)

                    qcol.setAlphaF(alpha)
                    curve.setPen(Qt.QPen(qcol))
                    alpha += alpha_step

                    plot.add_item(curve)
                    progress.step()

            for u in strong:
                color = helper.get_object_color(u)
                for s in strong[u]:
                    check_waveform(s)
                    x = (sp.arange(s.waveform.shape[0]) /
                         s.sampling_rate).rescale(time_unit)
                    # Black, wider outline below the colored curve
                    outline = make.curve(x,
                                         s.waveform[:, c].rescale(ref_units),
                                         color='#000000',
                                         linewidth=4)
                    curve = make.curve(x,
                                       s.waveform[:, c].rescale(ref_units),
                                       color=color,
                                       linewidth=2)
                    plot.add_item(outline)
                    plot.add_item(curve)
                    progress.step()

            _add_plot(plot, pw, win, c, len(channels), subplot_layout,
                      axes_style, time_unit, ref_units)

        helper.make_window_legend(win, indices, True)
    elif axes_style > 4:  # Only one plot needed
        pw = BaseCurveWidget(win)
        plot = pw.plot
        plot.set_antialiasing(anti_alias)

        if axes_style == 6:  # Horizontal split
            l = _split_plot_hor(channels, spikes, strong, fade, ref_units,
                                time_unit, progress, plot)

            plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
            plot.set_axis_unit(BasePlot.X_BOTTOM,
                               time_unit.dimensionality.string)
        else:  # Vertical split
            channels.reverse()

            max_offset = _find_y_offset(channels, spikes, strong, ref_units)
            l = _split_plot_ver(channels, spikes, strong, fade, ref_units,
                                time_unit, progress, max_offset, plot)

            plot.set_axis_title(BasePlot.Y_LEFT, 'Voltage')
            plot.set_axis_unit(BasePlot.Y_LEFT,
                               ref_units.dimensionality.string)

        win.add_plot_widget(pw, 0)
        win.add_legend_option([l], True)
    else:  # One plot per unit
        if axes_style == 3:
            channels.reverse()
            max_offset = _find_y_offset(channels, spikes, strong, ref_units)

        for i, u in enumerate(indices):
            pw = BaseCurveWidget(win)
            plot = pw.plot
            plot.set_antialiasing(anti_alias)

            spk = {}
            if u in spikes:
                spk[u] = spikes[u]
            st = {}
            if u in strong:
                st[u] = strong[u]

            if axes_style == 3:  # Vertical split
                _split_plot_ver(channels, spk, st, fade, ref_units, time_unit,
                                progress, max_offset, plot)
            else:  # Horizontal split
                _split_plot_hor(channels, spk, st, fade, ref_units, time_unit,
                                progress, plot)

            _add_plot(plot, pw, win, i, len(indices), subplot_layout,
                      axes_style, time_unit, ref_units)

    win.add_custom_curve_tools()
    progress.done()
    win.show()

    if axes_style <= 2:
        if len(channels) > 1:
            win.add_x_synchronization_option(True, channels)
            win.add_y_synchronization_option(True, channels)
    elif axes_style <= 4:
        # One plot was created per entry of ``indices`` (keys of both
        # ``spikes`` and ``strong``), not per key of ``spikes``.
        if len(indices) > 1:
            win.add_x_synchronization_option(True, range(len(indices)))
            win.add_y_synchronization_option(True, range(len(indices)))

    return win
Пример #3
0
def sde(trains,
        events=None,
        start=0 * pq.ms,
        stop=None,
        kernel_size=100 * pq.ms,
        optimize_steps=0,
        minimum_kernel=10 * pq.ms,
        maximum_kernel=500 * pq.ms,
        kernel=None,
        time_unit=pq.ms,
        progress=None):
    """ Create a spike density estimation plot.

    The spike density estimations give an estimate of the instantaneous
    rate. Optionally finds optimal kernel size for given data.

    :param dict trains: A dictionary of :class:`neo.core.SpikeTrain` lists.
        The dictionary passed in is not modified.
    :param dict events: A dictionary (with the same indices as ``trains``)
        of Event objects or lists of Event objects. In case of lists,
        the first event in the list will be used for alignment. The events
        will be at time 0 on the plot. If None, spike trains are used
        unmodified.
    :param start: The desired time for the start of the first bin. It
        will be recalculated if there are spike trains which start later
        than this time. This parameter can be negative (which could be
        useful when aligning on events).
    :type start: Quantity scalar
    :param stop: The desired time for the end of the last bin. It will
        be recalculated if there are spike trains which end earlier
        than this time.
    :type stop: Quantity scalar
    :param kernel_size: A uniform kernel size for all spike trains.
        Only used if optimization of kernel sizes is not used (i.e.
        ``optimize_steps`` is 0).
    :type kernel_size: Quantity scalar
    :param int optimize_steps: The number of different kernel sizes tried
        between ``minimum_kernel`` and ``maximum_kernel``.
        If 0, ``kernel_size`` will be used.
    :param minimum_kernel: The minimum kernel size to try in optimization.
    :type minimum_kernel: Quantity scalar
    :param maximum_kernel: The maximum kernel size to try in optimization.
    :type maximum_kernel: Quantity scalar
    :param kernel: The kernel function or instance to use, should accept
        two parameters: A ndarray of distances and a kernel size.
        The total area under the kernel function should be 1.
        Automatic optimization assumes a Gaussian kernel and will
        likely not produce optimal results for different kernels.
        Default: Gaussian kernel
    :type kernel: func or :class:`spykeutils.signal_processing.Kernel`
    :param Quantity time_unit: Unit of X-Axis.
    :param progress: Set this parameter to report progress.
    :type progress: :class:`spykeutils.progress_indicator.ProgressIndicator`
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If no spike density estimation was produced.
    """
    if not progress:
        progress = ProgressIndicator()

    # Use ``rescale`` (returns a new object) instead of assigning to
    # ``.units``, which converts in place and would modify the caller's
    # quantities as well as the shared default arguments of this function.
    start = start.rescale(time_unit)
    if stop is not None:
        stop = stop.rescale(time_unit)
    kernel_size = kernel_size.rescale(time_unit)
    minimum_kernel = minimum_kernel.rescale(time_unit)
    maximum_kernel = maximum_kernel.rescale(time_unit)

    if kernel is None:
        kernel = signal_processing.GaussianKernel(100 * pq.ms)

    # Align spike trains in a shallow copy so the caller's dictionary keeps
    # its original trains (repeated calls would otherwise align twice).
    if events:
        trains = trains.copy()
        for u in trains:
            trains[u] = rate_estimation.aligned_spike_trains(trains[u], events)

    # Calculate spike density estimation
    if optimize_steps:
        # All kernel bounds are in ``time_unit`` now; take the logarithm of
        # their plain magnitudes and reattach the unit afterwards.
        steps = sp.logspace(sp.log10(minimum_kernel.magnitude),
                            sp.log10(maximum_kernel.magnitude),
                            optimize_steps) * time_unit
        densities, kernel_sizes, eval_points = \
            rate_estimation.spike_density_estimation(
                trains, start, stop,
                optimize_steps=steps, kernel=kernel,
                progress=progress)
    else:
        densities, kernel_sizes, eval_points = \
            rate_estimation.spike_density_estimation(
                trains, start, stop,
                kernel_size=kernel_size, kernel=kernel,
                progress=progress)
    progress.done()

    if not densities:
        raise SpykeException('No spike trains for SDE!')

    # Plot
    win_title = 'Kernel Density Estimation'
    win = PlotDialog(toolbar=True, wintitle=win_title)

    pW = BaseCurveWidget(win)
    plot = pW.plot
    plot.set_antialiasing(True)
    for u in trains:
        if u and u.name:
            name = u.name
        else:
            name = 'Unknown'

        curve = make.curve(
            eval_points,
            densities[u],
            title='%s, Kernel width %.2f %s' %
            (name, kernel_sizes[u], time_unit.dimensionality.string),
            color=helper.get_object_color(u))
        plot.add_item(curve)

    plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
    plot.set_axis_unit(BasePlot.X_BOTTOM, eval_points.dimensionality.string)
    plot.set_axis_title(BasePlot.Y_LEFT, 'Rate')
    plot.set_axis_unit(BasePlot.Y_LEFT, 'Hz')
    l = make.legend()
    plot.add_item(l)

    win.add_plot_widget(pW, 0)
    win.add_custom_curve_tools()
    win.add_legend_option([l], True)
    win.show()

    return win
Пример #4
0
def cross_correlogram(trains,
                      bin_size,
                      max_lag=500 * pq.ms,
                      border_correction=True,
                      per_second=True,
                      square=False,
                      time_unit=pq.ms,
                      progress=None):
    """ Create a plot dialog with (cross-)correlograms computed from a
    dictionary of spike train lists for different units, one subplot per
    pair of units.

    :param dict trains: Dictionary of :class:`neo.core.SpikeTrain` lists.
    :param bin_size: Bin size (time).
    :type bin_size: Quantity scalar
    :param max_lag: Maximum time lag for which spikes are considered
        (end time of calculated correlogram).
    :type max_lag: Quantity scalar
    :param bool border_correction: Apply correction for less data at higher
        timelags.
    :param bool per_second: If ``True``, the y-axis is count per second,
        otherwise it is count per spike train.
    :param bool square: If ``True``, the plot will include all
        cross-correlograms, even if they are just mirrored versions of each
        other. The autocorrelograms are displayed as the diagonal of a
        square plot matrix. If ``False``, mirrored plots are omitted.
    :param Quantity time_unit: Unit of X-Axis.
    :param progress: Set this parameter to report progress.
    :type progress: :class:`spykeutils.progress_indicator.ProgressIndicator`
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If ``trains`` is empty.
    """
    if not trains:
        raise SpykeException('No spike trains for correlogram')
    if not progress:
        progress = ProgressIndicator()

    progress.begin('Creating correlogram')
    progress.set_status('Calculating...')
    win = PlotDialog(toolbar=True,
                     wintitle='Correlogram | Bin size ' + str(bin_size),
                     min_plot_width=150,
                     min_plot_height=100)

    correlograms, bins = correlogram(trains, bin_size, max_lag,
                                     border_correction, per_second, time_unit,
                                     progress)
    bin_centers = bins[:-1] + bin_size / 2

    # (values, first unit, second unit) for every plotted pair. Without
    # ``square``, only pairs on or above the diagonal are kept.
    keys = correlograms.keys()
    pairs = [(correlograms[first][second], first, second)
             for pos, first in enumerate(keys)
             for second in keys[(0 if square else pos):]]

    def legend_marker(obj):
        """ Invisible curve with a colored square marker for the legend. """
        color = helper.get_object_color(obj)
        return make.curve([], [],
                          obj.name,
                          color,
                          'NoPen',
                          linewidth=1,
                          marker='Rect',
                          markerfacecolor=color,
                          markeredgecolor=color)

    columns = int(sp.sqrt(len(pairs)))
    y_unit = 'count/second' if per_second else 'count/segment'

    legends = []
    for idx, (values, first, second) in enumerate(pairs):
        widget = BaseCurveWidget(win)
        plot = widget.plot
        plot.set_antialiasing(True)
        plot.add_item(make.curve(bin_centers, values))

        markers = [legend_marker(first)]
        plot.add_item(markers[0])
        if first != second:
            markers.append(legend_marker(second))
            plot.add_item(markers[-1])

        legend = make.legend(restrict_items=markers)
        legends.append(legend)
        plot.add_item(legend)

        # Axis labels only on the bottom row and the left column
        if idx >= len(pairs) - columns:
            plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
            plot.set_axis_unit(BasePlot.X_BOTTOM,
                               time_unit.dimensionality.string)
        if idx % columns == 0:
            plot.set_axis_title(BasePlot.Y_LEFT, 'Correlation')
            plot.set_axis_unit(BasePlot.Y_LEFT, y_unit)

        win.add_plot_widget(widget, idx, column=idx % columns)

    win.add_custom_curve_tools()
    progress.done()
    win.add_legend_option(legends, True)
    win.show()

    if len(pairs) > 1:
        win.add_x_synchronization_option(True, range(len(pairs)))
        win.add_y_synchronization_option(False, range(len(pairs)))

    return win
Пример #5
0
def signals(signals,
            events=None,
            epochs=None,
            spike_trains=None,
            spikes=None,
            show_waveforms=True,
            use_subplots=True,
            subplot_names=True,
            time_unit=pq.s,
            y_unit=None,
            progress=None):
    """ Create a plot from a list of analog signals.

    :param list signals: The list of :class:`neo.core.AnalogSignal` objects
        to plot.
    :param sequence events: A list of Event objects to be included in the
        plot.
    :param sequence epochs: A list of Epoch objects to be included in the
        plot.
    :param list spike_trains: A list of :class:`neo.core.SpikeTrain` objects
        to be included in the plot. The ``unit`` property (if it exists) is
        used for color and legend entries. The list itself is not modified.
    :param list spikes: A list :class:`neo.core.Spike` objects to be included
        in the plot. The ``unit`` property (if it exists) is used for color
        and legend entries. The list itself is not modified.
    :param bool show_waveforms: If ``True``, spikes and spike trains that
        have waveforms are shown as waveforms, while spike trains without
        waveforms are shown as vertical lines. If ``False``, all spikes are
        shown as vertical lines.
    :param bool use_subplots: Determines if a separate subplot is created
        for each signal.
    :param bool subplot_names: Only valid if ``use_subplots`` is True.
        Determines if signal (or channel) names are shown for subplots.
    :param Quantity time_unit: The unit of the x axis.
    :param Quantity y_unit: If not ``None``, signals are rescaled to this
        unit before plotting and the y axis is labeled accordingly.
    :param progress: Set this parameter to report progress.
    :type progress: :class:`spykeutils.progress_indicator.ProgressIndicator`
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If ``signals`` is empty.
    """
    if not signals:
        raise SpykeException(
            'Cannot create signal plot: No signal data provided!')
    if not progress:
        progress = ProgressIndicator()

    # Plot title
    win_title = 'Analog Signal'
    if len(set((s.recordingchannel for s in signals))) == 1:
        if signals[0].recordingchannel and signals[0].recordingchannel.name:
            win_title += ' | Recording Channel: %s' %\
                         signals[0].recordingchannel.name
    if len(set((s.segment for s in signals))) == 1:
        if signals[0].segment and signals[0].segment.name:
            win_title += ' | Segment: %s' % signals[0].segment.name
    win = PlotDialog(toolbar=True, wintitle=win_title)

    if events is None:
        events = []
    if epochs is None:
        epochs = []
    # Work on copies: both lists are extended below and must not leak
    # back into the caller's objects.
    spike_trains = [] if spike_trains is None else list(spike_trains)
    spikes = [] if spikes is None else list(spikes)

    if show_waveforms:
        # Trains with waveforms are drawn as individual spike waveforms;
        # trains without waveforms fall back to vertical lines.
        line_trains = []
        for st in spike_trains:
            if st.waveforms is not None:
                spikes.extend(conversions.spike_train_to_spikes(st))
            else:
                line_trains.append(st)
        spike_trains = line_trains
    else:
        unit_spikes = {}
        for s in spikes:
            unit_spikes.setdefault(s.unit, []).append(s)
        for sps in unit_spikes.itervalues():
            spike_trains.append(conversions.spikes_to_spike_train(sps, False))
        spikes = []

    channels = range(len(signals))

    channel_indices = []
    for s in signals:
        if not s.recordingchannel:
            channel_indices.append(-1)
        else:
            channel_indices.append(s.recordingchannel.index)

    # Heuristic: If multiple channels have the same index, use channel order
    # as index for spike waveforms
    nonindices = max(0, channel_indices.count(-1) - 1)
    if len(set(channel_indices)) != len(channel_indices) - nonindices:
        channel_indices = range(len(signals))

    progress.set_ticks((len(spike_trains) + len(spikes) + 1) * len(channels))

    offset = 0 * signals[0].units
    if use_subplots:
        plot = None
        for c in channels:
            pW = BaseCurveWidget(win)
            plot = pW.plot

            if subplot_names:
                if signals[c].name:
                    win.set_plot_title(plot, signals[c].name)
                elif signals[c].recordingchannel:
                    if signals[c].recordingchannel.name:
                        win.set_plot_title(plot,
                                           signals[c].recordingchannel.name)

            sample = (1 / signals[c].sampling_rate).simplified
            x = (sp.arange(signals[c].shape[0])) * sample + signals[c].t_start
            x.units = time_unit

            helper.add_epochs(plot, epochs, x.units)
            if y_unit is not None:
                y = signals[c].rescale(y_unit)
            else:
                y = signals[c]
            plot.add_item(make.curve(x, y))
            helper.add_events(plot, events, x.units)

            _add_spike_waveforms(plot, spikes, x.units, channel_indices[c],
                                 offset, progress)

            for train in spike_trains:
                color = helper.get_object_color(train.unit)
                helper.add_spikes(plot, train, color, units=x.units)
                progress.step()

            win.add_plot_widget(pW, c)
            # Label with the unit of the plotted data (``y_unit`` if given)
            plot.set_axis_unit(BasePlot.Y_LEFT, y.dimensionality.string)
            progress.step()

        plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
        plot.set_axis_unit(BasePlot.X_BOTTOM, x.dimensionality.string)
    else:
        channels.reverse()

        pW = BaseCurveWidget(win)
        plot = pW.plot

        helper.add_epochs(plot, epochs, time_unit)

        # Find plot y offset
        max_offset = 0 * signals[0].units
        for i, c in enumerate(channels[1:], 1):
            cur_offset = signals[channels[i - 1]].max() - signals[c].min()
            if cur_offset > max_offset:
                max_offset = cur_offset

        offset -= signals[channels[0]].min()

        y_label = signals[0].dimensionality.string
        for c in channels:
            sample = (1 / signals[c].sampling_rate).simplified
            x = (sp.arange(signals[c].shape[0])) * sample + signals[c].t_start
            x.units = time_unit

            y = signals[c] + offset
            if y_unit is not None:
                y = y.rescale(y_unit)
                y_label = y.dimensionality.string
            plot.add_item(make.curve(x, y))
            _add_spike_waveforms(plot, spikes, x.units, channel_indices[c],
                                 offset, progress)
            offset += max_offset
            progress.step()

        helper.add_events(plot, events, x.units)

        for train in spike_trains:
            color = helper.get_object_color(train.unit)
            helper.add_spikes(plot, train, color, units=x.units)
            progress.step()

        win.add_plot_widget(pW, 0)

        plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
        plot.set_axis_unit(BasePlot.X_BOTTOM, x.dimensionality.string)
        plot.set_axis_unit(BasePlot.Y_LEFT, y_label)

    win.add_custom_curve_tools()

    units = set([s.unit for s in spike_trains])
    units = units.union([s.unit for s in spikes])

    progress.done()

    helper.make_window_legend(win, units, False)
    win.show()

    # Synchronizing a single subplot is pointless (same guard as the other
    # multi-plot functions in this module).
    if use_subplots and len(channels) > 1:
        win.add_x_synchronization_option(True, channels)
        win.add_y_synchronization_option(False, channels)

    return win
Пример #6
0
def isi(trains, bin_size, cut_off, bar_plot=False, time_unit=pq.ms):
    """ Create a plot dialog with an interspike interval histogram.

    :param dict trains: Dictionary with lists of spike trains indexed by
        units for which to display ISI histograms
    :param bin_size: Bin size (time)
    :type bin_size: Quantity scalar
    :param cut_off: End of histogram (time). Bin edges run from 0 in steps
        of ``bin_size`` up to the largest multiple of ``bin_size`` that does
        not exceed ``cut_off``.
    :type cut_off: Quantity scalar
    :param bool bar_plot: If ``True``, create a bar ISI histogram for each
        index in ``trains``. Else, create a line ISI histogram.
    :param Quantity time_unit: Unit of X-Axis.
    :returns: The plot dialog that was created and shown.
    :raises SpykeException: If ``trains`` is empty or ``cut_off`` is
        smaller than ``bin_size``.
    """
    if not trains:
        raise SpykeException('No spike trains for ISI histogram')

    win_title = 'ISI Histogram | Bin size: ' + str(bin_size)
    bin_size = bin_size.rescale(time_unit)
    cut_off = cut_off.rescale(time_unit)

    step = float(bin_size.magnitude)
    # Number of whole bins up to cut_off. The small tolerance keeps ratios
    # such as 0.3 / 0.1 (== 2.999...) from losing the last bin to rounding.
    n_bins = int(float(cut_off.magnitude) / step + 1e-9)
    if n_bins < 1:
        raise SpykeException(
            'ISI histogram cut-off must not be smaller than the bin size')
    # Plain bin edges in ``time_unit``. Unlike ``arange(0, cut_off, ...)``,
    # this includes cut_off itself when it is a multiple of bin_size.
    edges = sp.arange(n_bins + 1) * step

    win = PlotDialog(toolbar=True,
                     wintitle=win_title,
                     min_plot_width=150,
                     min_plot_height=100)

    def intervals_of(train_list):
        """ Collect the interspike intervals (in ``time_unit``) of all
        spike trains in ``train_list``. """
        intervals = []
        for t in train_list:
            # sp.array copies, so sorting never touches the spike train
            times = sp.array(t.rescale(time_unit))
            times.sort()
            intervals.extend(sp.diff(times))
        return intervals

    def name_of(key):
        """ Legend name for a key of ``trains``. """
        if key and hasattr(key, 'name') and key.name:
            return key.name
        return 'Unknown'

    legends = []
    if bar_plot:
        columns = int(sp.sqrt(len(trains)))
        for ind, (i, train_list) in enumerate(trains.iteritems()):
            pW = BaseCurveWidget(win)
            plot = pW.plot
            counts, _ = sp.histogram(intervals_of(train_list), edges)
            name = name_of(i)

            # The step curve takes one y value per bin edge: repeat the
            # first count in front.
            show_isi = [counts[0]] + list(counts)
            curve = make.curve(edges,
                               show_isi,
                               name,
                               color='k',
                               curvestyle="Steps",
                               shade=1.0)
            plot.add_item(curve)

            # Create legend
            color = helper.get_object_color(i)
            color_curve = make.curve([], [],
                                     name,
                                     color,
                                     'NoPen',
                                     linewidth=1,
                                     marker='Rect',
                                     markerfacecolor=color,
                                     markeredgecolor=color)
            plot.add_item(color_curve)
            legends.append(make.legend(restrict_items=[color_curve]))
            plot.add_item(legends[-1])

            # Prepare plot
            plot.set_antialiasing(False)
            scale = plot.axisScaleDiv(BasePlot.Y_LEFT)
            plot.setAxisScale(BasePlot.Y_LEFT, 0, scale.upperBound())
            if ind % columns == 0:
                plot.set_axis_title(BasePlot.Y_LEFT, 'Number of intervals')
            if ind >= len(trains) - columns:
                plot.set_axis_title(BasePlot.X_BOTTOM, 'Interval length')
                plot.set_axis_unit(BasePlot.X_BOTTOM,
                                   time_unit.dimensionality.string)

            win.add_plot_widget(pW, ind, column=ind % columns)
    else:
        pW = BaseCurveWidget(win)
        plot = pW.plot
        legend_items = []

        for i, train_list in trains.iteritems():
            counts, _ = sp.histogram(intervals_of(train_list), edges)
            name = name_of(i)
            color = helper.get_object_color(i)

            # Left bin edges as x, so x and y have the same length
            curve = make.curve(edges[:-1], counts, name, color=color)
            legend_items.append(curve)
            plot.add_item(curve)

        win.add_plot_widget(pW, 0)

        legends.append(make.legend(restrict_items=legend_items))
        plot.add_item(legends[-1])

        plot.set_antialiasing(True)
        plot.set_axis_title(BasePlot.Y_LEFT, 'Number of intervals')
        plot.set_axis_title(BasePlot.X_BOTTOM, 'Interval length')
        plot.set_axis_unit(BasePlot.X_BOTTOM, time_unit.dimensionality.string)

    win.add_custom_curve_tools()
    win.add_legend_option(legends, True)
    win.show()

    if bar_plot and len(trains) > 1:
        win.add_x_synchronization_option(True, range(len(trains)))
        win.add_y_synchronization_option(False, range(len(trains)))

    return win
Пример #7
0
def psth(trains, events=None, start=0 * pq.ms, stop=None,
         bin_size=100 * pq.ms, rate_correction=True, bar_plot=False,
         time_unit=pq.ms, progress=None):
    """ Create a peri stimulus time histogram.

    The peri stimulus time histogram gives an estimate of the instantaneous
    rate.

    :param dict trains: A dictionary of :class:`neo.core.SpikeTrain` lists.
        The dictionary passed in is not modified.
    :param dict events: A dictionary of Event objects, indexed by segment.
        The events will be at time 0 on the plot. If None, spike trains
        are used unmodified.
    :param start: The desired time for the start of the first bin. It
        will be recalculated if there are spike trains which start later
        than this time. This parameter can be negative (which could be
        useful when aligning on events).
    :type start: Quantity scalar
    :param stop: The desired time for the end of the last bin. It will
        be recalculated if there are spike trains which end earlier
        than this time.
    :type stop: Quantity scalar
    :param bin_size: The bin size for the histogram.
    :type bin_size: Quantity scalar
    :param bool rate_correction: Determines if a rates (``True``) or
        counts (``False``) are shown.
    :param bool bar_plot: Determines if a bar plot (``True``) or a line
        plot (``False``) will be created. In case of a bar plot, one plot
        for each index in ``trains`` will be created.
    :param Quantity time_unit: Unit of X-Axis. Milliseconds are used if
        this evaluates to False.
    :param progress: Set this parameter to report progress.
    :type progress: :class:`spykeutils.progress_indicator.ProgressIndicator`
    :returns: The plot window that was created and shown.
    :raises SpykeException: If ``trains`` is empty or no rates were
        computed from it.
    """
    if not trains:
        raise SpykeException('No spike trains for PSTH!')
    if not progress:
        progress = ProgressIndicator()
    if not time_unit:
        time_unit = pq.ms

    # Align spike trains on the events. Work on a copy (copy() keeps the
    # mapping type, e.g. an OrderedDict) so the caller's dictionary is not
    # overwritten with the aligned trains.
    if events:
        trains = trains.copy()
        for u in trains:
            trains[u] = rate_estimation.aligned_spike_trains(
                trains[u], events)

    rates, bins = rate_estimation.psth(
        trains, bin_size, start=start, stop=stop,
        rate_correction=rate_correction)
    bins = bins.rescale(time_unit)

    if not rates:
        raise SpykeException('No spike trains for PSTH!')

    # Report the bin size in the same unit that is used in the label.
    win_title = 'PSTH | Bin size %.2f %s' % (
        float(bin_size.rescale(time_unit)),
        time_unit.dimensionality.string)
    win = PlotDialog(toolbar=True, wintitle=win_title, min_plot_width=150,
                     min_plot_height=100)

    legends = []
    if bar_plot:
        ind = 0
        # Plots are laid out in a roughly square grid.
        columns = int(sp.sqrt(len(rates)))
        for i, r in rates.iteritems():
            if i and hasattr(i, 'name') and i.name:
                name = i.name
            else:
                name = 'Unknown'

            pW = BaseCurveWidget(win)
            plot = pW.plot

            # The step curve is drawn over the bin edges, which presumably
            # outnumber the rates by one; repeating the first rate makes the
            # y values line up with the edges.
            show_rates = list(r)
            show_rates.insert(0, show_rates[0])
            curve = make.curve(
                bins, show_rates, name, color='k',
                curvestyle="Steps", shade=1.0)
            plot.add_item(curve)

            # Create legend: an invisible curve with a colored marker that
            # identifies the object the rates belong to.
            color = helper.get_object_color(i)
            color_curve = make.curve(
                [], [], name, color, 'NoPen', linewidth=1, marker='Rect',
                markerfacecolor=color, markeredgecolor=color)
            plot.add_item(color_curve)
            legends.append(make.legend(restrict_items=[color_curve]))
            plot.add_item(legends[-1])

            # Prepare plot: y-axis starts at zero, axis titles only on the
            # first column (y) and the last row (x) of the grid.
            plot.set_antialiasing(False)
            scale = plot.axisScaleDiv(BasePlot.Y_LEFT)
            plot.setAxisScale(BasePlot.Y_LEFT, 0, scale.upperBound())
            if ind % columns == 0:
                if not rate_correction:
                    plot.set_axis_title(BasePlot.Y_LEFT, 'Spike Count')
                else:
                    plot.set_axis_title(BasePlot.Y_LEFT, 'Rate')
                    plot.set_axis_unit(BasePlot.Y_LEFT, 'Hz')
            if ind >= len(rates) - columns:
                plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
                plot.set_axis_unit(
                    BasePlot.X_BOTTOM, time_unit.dimensionality.string)

            win.add_plot_widget(pW, ind, column=ind % columns)
            ind += 1
    else:
        # Line plot: use the bin centers instead of the bin edges.
        bins = 0.5 * sp.diff(bins) + bins[:-1]

        pW = BaseCurveWidget(win)
        plot = pW.plot
        legend_items = []

        for i, r in rates.iteritems():
            if i and hasattr(i, 'name') and i.name:
                name = i.name
            else:
                name = 'Unknown'

            curve = make.curve(
                bins, r, name,
                color=helper.get_object_color(i))
            legend_items.append(curve)
            plot.add_item(curve)

        win.add_plot_widget(pW, 0)

        legends.append(make.legend(restrict_items=legend_items))
        plot.add_item(legends[-1])

        if not rate_correction:
            plot.set_axis_title(BasePlot.Y_LEFT, 'Spike Count')
        else:
            plot.set_axis_title(BasePlot.Y_LEFT, 'Rate')
            plot.set_axis_unit(BasePlot.Y_LEFT, 'Hz')
        plot.set_axis_title(BasePlot.X_BOTTOM, 'Time')
        plot.set_axis_unit(BasePlot.X_BOTTOM, time_unit.dimensionality.string)
        plot.set_antialiasing(True)

    win.add_custom_curve_tools()
    win.add_legend_option(legends, True)
    progress.done()
    win.show()

    if bar_plot and len(rates) > 1:
        win.add_x_synchronization_option(True, range(len(rates)))
        win.add_y_synchronization_option(False, range(len(rates)))

    return win