コード例 #1
0
class ISIPlugin(analysis_plugin.AnalysisPlugin):
    """Plot an interspike interval (ISI) histogram.

    Spike trains are taken either from the current units or from the saved
    selections (one histogram entry per selection) and passed to
    ``plot.isi``.
    """
    bin_size = gui_data.FloatItem('Bin size', 1.0, 0.1, 10000.0, unit='ms')
    cut_off = gui_data.FloatItem('Cut off', 50.0, 2.0, 10000.0, unit='ms')
    diagram_type = gui_data.ChoiceItem('Type', ('Bar', 'Line'))
    data_source = gui_data.ChoiceItem('Data source', ('Units', 'Selections'))

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Interspike Interval Histogram'

    def start(self, current, selections):
        """Collect spike trains and draw the ISI histogram.

        :param current: Provider for the currently selected data.
        :param selections: Saved selections, used when the data source is
            "Selections".
        """
        current.progress.begin('Creating Interspike Interval Histogram')
        from_units = self.data_source == 0
        if from_units:
            trains = current.spike_trains_by_unit()
        else:
            # isi() expects a dictionary: one entry of spike trains per
            # selection, keyed by a unit object named after the selection.
            trains = {}
            for sel in selections:
                sel_trains = sel.spike_trains()
                trains[neo.Unit(sel.name)] = sel_trains
        current.progress.done()

        bin_width = self.bin_size * pq.ms
        max_interval = self.cut_off * pq.ms
        as_bars = self.diagram_type == 0  # 'Bar' is the first choice
        plot.isi(trains, bin_width, max_interval, as_bars, time_unit=pq.ms)
コード例 #2
0
class CorrelogramPlugin(analysis_plugin.AnalysisPlugin):
    """Plot correlograms of spike trains via ``plot.cross_correlogram``.

    Spike trains come from the current units or from saved selections.
    Progress reporting is handed on to the plotting function.
    """
    bin_size = gui_data.FloatItem('Bin size', 1.0, 0.001, 10000.0, unit='ms')
    cut_off = gui_data.FloatItem('Cut off', 50.0, 2.0, 10000.0, unit='ms')
    data_source = gui_data.ChoiceItem('Data source', ('Units', 'Selections'))
    count_per = gui_data.ChoiceItem('Counts per', ('Second', 'Segment'))
    border_correction = gui_data.BoolItem('Border correction', default=True)
    square = gui_data.BoolItem('Include mirrored plots')

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Correlogram'

    def start(self, current, selections):
        """Collect spike trains and draw the correlogram.

        :param current: Provider for the currently selected data.
        :param selections: Saved selections, used when the data source is
            "Selections".
        """
        current.progress.begin('Creating correlogram')
        if self.data_source != 0:
            # cross_correlogram() expects a dictionary: one entry of spike
            # trains per selection, keyed by a unit named after it.
            trains = {}
            for sel in selections:
                sel_trains = sel.spike_trains()
                trains[neo.Unit(sel.name)] = sel_trains
        else:
            trains = current.spike_trains_by_unit()

        bin_width = self.bin_size * pq.ms
        max_lag = self.cut_off * pq.ms
        per_second = self.count_per == 0  # 'Second' is the first choice
        plot.cross_correlogram(trains, bin_width, max_lag,
                               border_correction=self.border_correction,
                               per_second=per_second,
                               square=self.square,
                               progress=current.progress)
コード例 #3
0
ファイル: raster_plot.py プロジェクト: stjordanis/spykeviewer
class RasterPlotPlugin(analysis_plugin.AnalysisPlugin):
    """Raster plot of spike trains, optionally with events and epochs.

    Depending on ``domain``, spike trains are grouped by unit or by segment.
    Only the first spike train of each group is passed to ``plot.raster``.
    """
    domain = gui_data.ChoiceItem('Domain', ('Units', 'Segments'))
    show_lines = gui_data.BoolItem('Show lines', default=True)
    show_events = gui_data.BoolItem('Show events', default=True)
    show_epochs = gui_data.BoolItem('Show epochs', default=True)

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Raster Plot'

    def _domain_items(self, per_segment):
        """Pick the events or epochs that match the selected domain.

        :param per_segment: Dictionary as returned by ``current.events()``
            or ``current.epochs()``, whose values are sequences of items.
        :returns: For the "Units" domain, the value of the first dictionary
            entry (None if the dictionary is empty). For the "Segments"
            domain, a flat list of the items of all entries.
        """
        if self.domain == 0:  # Only items for the displayed segment
            if per_segment:
                # list() so this also works where values() is a view.
                return list(per_segment.values())[0]
            return None
        # Items for all segments
        return [item for seg_items in per_segment.values()
                for item in seg_items]

    def start(self, current, selections):
        """Collect spike trains (plus events/epochs) and draw the raster."""
        current.progress.begin('Creating raster plot')

        if self.domain == 0:  # Units
            d = current.spike_trains_by_unit()
        else:  # Segments
            d = current.spike_trains_by_segment()

        # Only show first spike train for each index; drop empty entries.
        # Iterate over a snapshot of the keys because entries are removed
        # inside the loop (mutating a dict while iterating a keys view
        # raises RuntimeError).
        for k in list(d.keys()):
            if d[k]:
                d[k] = d[k][0]
            else:
                d.pop(k)

        events = None
        if self.show_events:
            events = self._domain_items(current.events())

        epochs = None
        if self.show_epochs:
            epochs = self._domain_items(current.epochs())

        current.progress.done()
        plot.raster(d, pq.ms, self.show_lines, events, epochs)
コード例 #4
0
class SDEPlugin(analysis_plugin.AnalysisPlugin):
    """Spike density estimation plot, optionally with kernel optimization.

    Time and kernel parameters are configured in milliseconds and turned
    into quantities using ``self.unit`` before being passed to ``plot.sde``.
    """
    # Configurable parameters (declaration order defines the dialog layout)
    kernel_size = gui_data.FloatItem('Kernel size', min=1.0, default=300.0,
                                     unit='ms')
    start_time = gui_data.FloatItem('Start time', default=0.0, unit='ms')

    stop_enabled = gui_data.BoolItem(
        'Stop time enabled', default=False).set_prop('display',
                                                     store=stop_prop)
    stop = gui_data.FloatItem(
        'Time', default=10001.0, unit='ms').set_prop('display',
                                                     active=stop_prop)

    align_enabled = gui_data.BoolItem(
        'Alignment event enabled', default=False).set_prop('display',
                                                           store=align_prop)
    align = gui_data.StringItem('Event label').set_prop('display',
                                                        active=align_prop)
    data_source = gui_data.ChoiceItem('Data source', ('Units', 'Selections'))

    _g = gui_data.BeginGroup('Kernel width optimization')
    optimize_enabled = gui_data.BoolItem(
        'Enabled', default=False).set_prop('display', store=optimize_prop)
    minimum_kernel = gui_data.FloatItem(
        'Minimum kernel size', default=10.0, unit='ms',
        min=0.5).set_prop('display', active=optimize_prop)
    maximum_kernel = gui_data.FloatItem(
        'Maximum kernel size', default=1000.0, unit='ms',
        min=1.0).set_prop('display', active=optimize_prop)
    optimize_steps = gui_data.IntItem(
        'Kernel size steps', default=30,
        min=2).set_prop('display', active=optimize_prop)
    # NOTE(review): this label differs from the BeginGroup label above
    # ('Kernel width optimization'); kept unchanged here.
    _g_ = gui_data.EndGroup('Kernel size optimization')

    def __init__(self):
        super(SDEPlugin, self).__init__()
        self.unit = pq.ms

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Spike Density Estimation'

    def _load_trains_and_events(self, current, selections):
        """Return ``(trains, events)`` for the configured data source.

        ``events`` stays None unless alignment is enabled. When it is set
        and non-empty, each entry is reduced to its first event, which is
        used for alignment.
        """
        events = None
        if self.data_source == 0:
            trains = current.spike_trains_by_unit()
            if self.align_enabled:
                events = current.labeled_events(self.align)
        else:
            # One entry of spike trains per selection; events from all
            # selections are merged into one dictionary.
            trains = {}
            if self.align_enabled:
                events = {}
            for sel in selections:
                trains[neo.Unit(sel.name)] = sel.spike_trains()
                if self.align_enabled:
                    events.update(sel.labeled_events(self.align))

        if events:
            for seg in events:  # Align on first event in each segment
                events[seg] = events[seg][0]
        return trains, events

    def start(self, current, selections):
        """Compute the spike density estimation and plot it."""
        current.progress.begin('Creating spike density estimation')

        # Convert configured values into quantities
        unit = self.unit
        start = float(self.start_time) * unit
        stop = float(self.stop) * unit if self.stop_enabled else None
        kernel_size = float(self.kernel_size) * unit
        # Zero steps means no kernel width optimization
        optimize_steps = self.optimize_steps if self.optimize_enabled else 0
        minimum_kernel = self.minimum_kernel * unit
        maximum_kernel = self.maximum_kernel * unit

        trains, events = self._load_trains_and_events(current, selections)

        plot.sde(trains, events, start, stop, kernel_size, optimize_steps,
                 minimum_kernel, maximum_kernel, None, unit,
                 current.progress)

    def _kernel_range_invalid(self):
        """Tell whether optimization is on with max <= min kernel size."""
        return (self.optimize_enabled and
                self.maximum_kernel <= self.minimum_kernel)

    def configure(self):
        """Show the configuration dialog and repeat it while the kernel
        optimization range is invalid."""
        super(SDEPlugin, self).configure()
        while self._kernel_range_invalid():
            QMessageBox.warning(
                None, 'Unable to set parameters',
                'Maximum kernel size needs to be larger than '
                'minimum kernel size!')
            super(SDEPlugin, self).configure()
コード例 #5
0
class SpectrogramPlugin(analysis_plugin.AnalysisPlugin):
    """Show a spectrogram image for each selected analog signal.

    Spectrograms are computed with ``mlab.specgram`` using the chosen FFT
    window size and half-window overlap. The plots go into one
    ``PlotDialog``, with the column index cycling over
    ``round(sqrt(number of signals))`` columns.
    """
    nfft_index = (32, 64, 128, 256, 512, 1024, 2048, 4096, 8192)
    # A tuple rather than a generator: a generator stored as a class
    # attribute can only be iterated once.
    nfft_names = tuple(str(s) for s in nfft_index)
    interpolate = gui_data.BoolItem('Interpolate', default=True)
    show_color_bar = gui_data.BoolItem('Show color bar', default=False)
    fft_samples = gui_data.ChoiceItem('FFT samples', nfft_names, default=3)
    which_signals = gui_data.ChoiceItem(
        'Included signals', ('AnalogSignal', 'AnalogSignalArray', 'Both'),
        default=2)

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Signal Spectogram'

    @helper.needs_qt
    def start(self, current, selections):
        """Compute and display the spectrograms.

        :raises SpykeException: If no signals are selected.
        """
        current.progress.begin('Creating Spectogram')
        signals = current.analog_signals(self.which_signals + 1)
        if not signals:
            current.progress.done()
            raise SpykeException('No signals selected!')

        num_signals = len(signals)

        columns = int(round(sp.sqrt(num_signals)))

        current.progress.set_ticks(num_signals)
        samples = self.nfft_index[self.fft_samples]
        win = PlotDialog(toolbar=True,
                         wintitle="Signal Spectogram (FFT window size %d)" %
                         samples)

        # Identical for every plot, so decide once outside the loop.
        interpolation = 'linear' if self.interpolate else 'nearest'

        for c, s in enumerate(signals):
            pW = BaseImageWidget(win, yreverse=False, lock_aspect_ratio=False)
            # Named img_plot to avoid shadowing the module-level `plot`.
            img_plot = pW.plot

            # Calculate spectrogram and create plot. Floor division keeps
            # noverlap an integer even under true division.
            v = mlab.specgram(s,
                              NFFT=samples,
                              noverlap=samples // 2,
                              Fs=s.sampling_rate.rescale(pq.Hz))
            # v[1] spans the frequency axis, v[2] the time axis (shifted by
            # the signal's start time).
            t_start = s.t_start.rescale(pq.s)
            img = make.image(sp.log(v[0]),
                             ydata=[v[1][0], v[1][-1]],
                             xdata=[v[2][0] + t_start, v[2][-1] + t_start],
                             interpolation=interpolation)
            img_plot.add_item(img)

            # Labels etc.
            if not self.show_color_bar:
                img_plot.disable_unused_axes()
            title = ''
            if s.recordingchannel and s.recordingchannel.name:
                title = s.recordingchannel.name
            if s.segment and s.segment.name:
                if title:
                    title += ' , '
                title += s.segment.name
            img_plot.set_title(title)
            img_plot.set_axis_title(img_plot.Y_LEFT, 'Frequency')
            img_plot.set_axis_unit(img_plot.Y_LEFT,
                                   s.sampling_rate.dimensionality.string)
            img_plot.set_axis_title(img_plot.X_BOTTOM, 'Time')
            time_unit = (1 / s.sampling_rate).simplified
            img_plot.set_axis_unit(img_plot.X_BOTTOM,
                                   time_unit.dimensionality.string)
            win.add_plot_widget(pW, c, column=c % columns)
            current.progress.step()

        current.progress.done()
        win.add_custom_image_tools()
        plot_ids = list(range(num_signals))
        win.add_x_synchronization_option(True, plot_ids)
        win.add_y_synchronization_option(True, plot_ids)
        win.show()
コード例 #6
0
class PSTHPlugin(analysis_plugin.AnalysisPlugin):
    """Peristimulus time histogram of spike trains.

    Optionally aligns trials on the first event with a given label in each
    segment and restricts the histogram to a start/stop window.
    """
    # Configurable parameters (declaration order defines the dialog layout)
    bin_size = gui_data.FloatItem('Bin size', min=1.0, default=500.0,
                                  unit='ms')
    start_time = gui_data.FloatItem('Start time', default=0.0, unit='ms')

    stop_enabled = gui_data.BoolItem(
        'Stop time enabled', default=False).set_prop('display',
                                                     store=stop_prop)
    stop = gui_data.FloatItem(
        'Time', default=10001.0, unit='ms').set_prop('display',
                                                     active=stop_prop)

    align_enabled = gui_data.BoolItem(
        'Alignment event enabled', default=False).set_prop('display',
                                                           store=align_prop)
    align = gui_data.StringItem('Event label').set_prop('display',
                                                        active=align_prop)
    diagram_type = gui_data.ChoiceItem('Type', ('Bar', 'Line'))
    data_source = gui_data.ChoiceItem('Data source', ('Units', 'Selections'))

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Peristimulus Time Histogram'

    def _load_trains_and_events(self, current, selections):
        """Return ``(trains, events)`` for the configured data source.

        ``events`` stays None unless alignment is enabled. When it is set
        and non-empty, each entry is reduced to its first event, which is
        used for alignment.
        """
        events = None
        if self.data_source == 0:
            trains = current.spike_trains_by_unit()
            if self.align_enabled:
                events = current.labeled_events(self.align)
        else:
            # One entry of spike trains per selection; events from all
            # selections are merged into one dictionary.
            trains = {}
            if self.align_enabled:
                events = {}
            for sel in selections:
                trains[neo.Unit(sel.name)] = sel.spike_trains()
                if self.align_enabled:
                    events.update(sel.labeled_events(self.align))

        if events:
            for seg in events:  # Align on first event in each segment
                events[seg] = events[seg][0]
        return trains, events

    def start(self, current, selections):
        """Collect data and draw the PSTH."""
        # Convert configured millisecond values into quantities
        ms = pq.ms
        start = float(self.start_time) * ms
        stop = float(self.stop) * ms if self.stop_enabled else None
        bin_size = float(self.bin_size) * ms

        current.progress.begin('Creating PSTH')
        trains, events = self._load_trains_and_events(current, selections)

        as_bars = self.diagram_type == 0  # 'Bar' is the first choice
        plot.psth(trains, events, start, stop, bin_size,
                  rate_correction=True,
                  time_unit=ms,
                  bar_plot=as_bars,
                  progress=current.progress)
コード例 #7
0
class SignalPlotPlugin(analysis_plugin.AnalysisPlugin):
    """Plot analog signals per segment, optionally with events, epochs and
    spikes.

    When "Use first spike as template" is active and spikes are shown as
    waveforms, the first spike of each unit is removed from the spike list.
    Copies of it, placed at each time of that unit's spike trains, are
    drawn instead of the trains.
    """
    subplots = gui_data.BoolItem('Use subplots', default=True).set_prop(
        'display', store=subplot_prop)
    subplot_titles = gui_data.BoolItem(
        'Show subplot names', default=True).set_prop('display',
                                                     active=subplot_prop)
    which_signals = gui_data.ChoiceItem('Included signals',
                                        ('AnalogSignal',
                                         'AnalogSignalArray', 'Both'),
                                        default=2)
    show_events = gui_data.BoolItem('Show events', default=True)
    show_epochs = gui_data.BoolItem('Show epochs', default=True)
    multiple_plots = gui_data.BoolItem('One plot per segment', default=False)

    _g = gui_data.BeginGroup('Spikes')
    show_spikes = gui_data.BoolItem(
        'Show spikes').set_prop('display', store=spike_prop)
    spike_form = gui_data.ChoiceItem(
        'Display as', ('Waveforms', 'Lines')).set_prop('display',
                                                       active=spike_prop)
    spike_mode = gui_data.ChoiceItem(
        'Included data', ('Spikes', 'Spike Trains', 'Both'),
        default=2).set_prop('display', active=spike_prop)
    template_mode = gui_data.BoolItem(
        'Use first spike as template').set_prop('display', active=spike_prop)
    _g_ = gui_data.EndGroup('Spikes')

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Signal Plot'

    def start(self, current, selections):
        """Load the selected data and create the signal plot(s).

        :raises SpykeException: If no signals are selected.
        """
        current.progress.begin('Creating signal plot')

        signals = current.analog_signals_by_segment(self.which_signals + 1)

        if not signals:
            # Close the progress indicator before aborting so it is not left
            # open (as other plugins do before raising).
            current.progress.done()
            raise SpykeException('No signals selected!')

        # Load supplemental data
        events = None
        if self.show_events:
            current.progress.set_status('Loading events')
            events = current.events()

        epochs = None
        if self.show_epochs:
            current.progress.set_status('Loading epochs')
            epochs = current.epochs()

        spike_trains = None
        if self.show_spikes and self.spike_mode > 0:
            current.progress.set_status('Loading spike trains')
            spike_trains = current.spike_trains_by_segment()

        spikes = None
        if self.show_spikes and self.spike_mode != 1:
            current.progress.set_status('Loading spikes')
            spikes = current.spikes_by_segment()

        # Create plot
        segments = set(signals.keys())
        for seg in segments:
            current.progress.begin('Creating signal plot...')
            current.progress.set_status('Constructing plot')
            seg_events = None
            if events and seg in events:
                seg_events = events[seg]

            seg_epochs = None
            if epochs and seg in epochs:
                seg_epochs = epochs[seg]

            seg_trains = []
            if spike_trains and seg in spike_trains:
                seg_trains = spike_trains[seg]

            seg_spikes = []
            if spikes and seg in spikes:
                seg_spikes = spikes[seg]

            # Prepare template spikes: take the first spike of each unit as
            # template, then replace that unit's spike trains by copies of
            # the template at each spike time.
            if self.spike_form == 0 and self.template_mode:
                template_spikes = {}
                for s in seg_spikes[:]:
                    if s.unit not in template_spikes:
                        template_spikes[s.unit] = s
                for ts in template_spikes.values():
                    seg_spikes.remove(ts)

                for st in seg_trains[:]:
                    if st.unit not in template_spikes:
                        continue
                    for t in st:
                        s = copy(template_spikes[st.unit])
                        s.time = t
                        seg_spikes.append(s)
                    seg_trains.remove(st)

            plot.signals(signals[seg], events=seg_events,
                         epochs=seg_epochs, spike_trains=seg_trains,
                         spikes=seg_spikes, use_subplots=self.subplots,
                         show_waveforms=(self.spike_form == 0),
                         subplot_names=self.subplot_titles,
                         progress=current.progress)

            if not self.multiple_plots:
                break
コード例 #8
0
class SpikePlotPlugin(analysis_plugin.AnalysisPlugin):
    """Plot spike waveforms grouped by unit.

    Waveforms can come from three sources:
    - the spikes of the current units, passed either as regular spikes or
      as the ``strong`` set depending on ``spike_mode``;
    - spike trains that carry waveforms;
    - waveforms cut out of analog signals at spike train times.
    """
    anti_aliased = gui_data.BoolItem(
        'Antialiased lines (slow for '
        'large amounts of spikes)',
        default=True)

    _g = gui_data.BeginGroup('Include spikes from')
    spike_mode = gui_data.ChoiceItem(
        'Spikes', ('Do not include', 'Regular', 'Emphasized'), default=1)
    inc_spikes = gui_data.BoolItem('Spike Trains')
    inc_extracted = gui_data.BoolItem('Extracted from signal').set_prop(
        'display', store=extract_prop)
    length = gui_data.FloatItem(
        'Spike length', unit='ms', default=1.0).set_prop('display',
                                                         active=extract_prop)
    align = gui_data.FloatItem(
        'Alignment offset', unit='ms',
        default=0.5).set_prop('display', active=extract_prop)
    _g_ = gui_data.EndGroup('Include spikes from')
    plot_type = gui_data.ChoiceItem(
        'Plot type',
        ('One plot per channel', 'One plot per unit', 'Single plot'))
    split_type = gui_data.ChoiceItem('Split channels',
                                     ('Vertically', 'Horizontally'))
    layout = gui_data.ChoiceItem('Subplot layout', ('Linear', 'Square'))
    fade = gui_data.BoolItem('Fade earlier spikes')

    def get_name(self):
        """Return the display name of this plugin."""
        return 'Spike Waveform Plot'

    def start(self, current, selections):
        """Gather spikes from the configured sources and plot them."""
        current.progress.begin('Creating spike waveform plot')
        current.progress.set_status('Loading spikes')

        spikes, strong = {}, {}
        if self.spike_mode == 1:
            spikes = current.spikes_by_unit()
        elif self.spike_mode == 2:
            strong = current.spikes_by_unit()

        def add_spikes(unit, new_spikes):
            # Append new_spikes to the unit's entry; units without any new
            # spikes get no entry.
            if new_spikes:
                spikes.setdefault(unit, []).extend(new_spikes)

        spike_trains = None
        if self.inc_spikes:
            current.progress.set_status('Loading spike trains')
            spike_trains = current.spike_trains_by_unit_and_segment()

            for unit, trains_by_seg in spike_trains.iteritems():
                from_waveforms = []
                for train in trains_by_seg.values():
                    # Only trains with attached waveforms yield spikes
                    if train.waveforms is None:
                        continue
                    from_waveforms.extend(
                        convert.spike_train_to_spikes(train))
                add_spikes(unit, from_waveforms)

        if self.inc_extracted:
            current.progress.set_status('Extracting spikes from signals')
            signals = current.analog_signals_by_segment_and_channel(
                conversion_mode=3)
            if spike_trains is None:
                spike_trains = current.spike_trains_by_unit_and_segment()

            for unit, trains_by_seg in spike_trains.iteritems():
                extracted = []
                group = unit.recordingchannelgroup
                for seg, train in trains_by_seg.iteritems():
                    if seg not in signals:
                        continue
                    seg_signals = signals[seg]
                    # Signals of channels belonging to the unit's group
                    train_sigs = [seg_signals[rc] for rc in seg_signals
                                  if group in rc.recordingchannelgroups]
                    if train_sigs:
                        extracted.extend(
                            extract_spikes(train, train_sigs,
                                           self.length * pq.ms,
                                           self.align * pq.ms))
                add_spikes(unit, extracted)

        # Encodes plot type and channel split direction into one mode value
        plot_mode = self.plot_type * 2 + self.split_type + 1
        fade = 0.2 if self.fade else 1.0
        plot.spikes(spikes, plot_mode, strong,
                    anti_alias=self.anti_aliased,
                    fade=fade,
                    subplot_layout=self.layout,
                    progress=current.progress)