Esempio n. 1
0
    def plot(self, id_list=None, v_thresh=None, display=True, kwargs=None):
        """
        Plot all cells in the AnalogSignalList defined by id_list

        Inputs:
            id_list  - can be a integer (and then N cells are randomly selected) or a
                       list of ids. If None, we use all the ids of the list
            v_thresh - accepted but not used by this method
            display  - if True, a new figure is created. Could also be a subplot
            kwargs   - dictionary containing extra parameters that will be sent to the
                       plot function (defaults to no extra parameters)

        Examples:
            >> z = subplot(221)
            >> aslist.plot(5, display=z, kwargs={'color':'r'})
        """
        # A fresh dict per call instead of a shared mutable default.
        if kwargs is None:
            kwargs = {}
        subplot = get_display(display)
        id_list = self._AnalogSignalList__sub_id_list(id_list)
        time_axis = self.time_axis()
        if not subplot or not HAVE_PYLAB:
            # Single-argument print(...) behaves identically on Python 2 and 3.
            print(PYLAB_ERROR)
        else:
            set_labels(subplot, "Time (ms)", "Conductance (nS)")
            # hold() no longer exists on recent Matplotlib; only call it
            # where the plotting target still provides it.
            can_hold = hasattr(subplot, "hold")
            for cell_id in id_list:
                subplot.plot(time_axis, self.analog_signals[cell_id].signal,
                             **kwargs)
                if can_hold:
                    subplot.hold(1)
            pylab.draw()
def show_results(result):
    """
    Visualizes the result of the parameter search as a colour-coded image
    and saves it as a PNG file in the current working directory.

    Parameters:
    result - list of result dictionaries; every entry is read through the
             keys 'source_rate', 'weight' and 'neuron_rate'.
    """
    import os
    import tempfile

    import numpy
    import NeuroTools.plotting as plotting

    # One grid position per distinct parameter value (sorted ascending).
    rates = numpy.unique([r['source_rate'] for r in result])
    weights = numpy.unique([r['weight'] for r in result])

    # Index the results once; setdefault keeps the first entry when a
    # (rate, weight) combination occurs more than once.
    rate_lookup = {}
    for r in result:
        rate_lookup.setdefault((r['source_rate'], r['weight']),
                               r['neuron_rate'])

    # imshow draws columns along x and rows along y, so the matrix is laid
    # out as [weight, rate] to agree with the axis labels set below.
    # Combinations missing from `result` stay NaN (drawn as empty cells).
    neuron_rates = numpy.full((len(weights), len(rates)), numpy.nan)
    for w_i, weight in enumerate(weights):
        for r_i, rate in enumerate(rates):
            neuron_rates[w_i, r_i] = rate_lookup.get((rate, weight),
                                                     numpy.nan)

    pylab = plotting.get_display(True)
    pylab.rcParams.update(plotting.pylab_params())
    image = pylab.imshow(neuron_rates,
                         interpolation='nearest',
                         origin='lower')
    # `.axes` replaces the get_axes() accessor that newer Matplotlib dropped.
    plotting.set_labels(image.axes,
                        xlabel='rate',
                        ylabel='weight')
    pylab.colorbar()
    # could add fancy xticks and yticks here
    (fd, figfilename) = tempfile.mkstemp(prefix='parameter_search_result',
                                         suffix='.png',
                                         dir=os.getcwd())
    # mkstemp hands back an open descriptor; only the unique name is needed.
    os.close(fd)
    pylab.gcf().savefig(figfilename)
Esempio n. 3
0
    def plot(self, id_list=None, v_thresh=None, display=True, kwargs=None):
        """
        Plot all cells in the AnalogSignalList defined by id_list

        Inputs:
            id_list  - can be a integer (and then N cells are randomly selected) or a
                       list of ids. If None, we use all the ids of the list
            v_thresh - accepted but not used by this method
            display  - if True, a new figure is created. Could also be a subplot
            kwargs   - dictionary containing extra parameters that will be sent to the
                       plot function (defaults to no extra parameters)

        Examples:
            >> z = subplot(221)
            >> aslist.plot(5, display=z, kwargs={'color':'r'})
        """
        # Avoid a mutable default argument shared between calls.
        kwargs = {} if kwargs is None else kwargs
        subplot   = get_display(display)
        id_list   = self._AnalogSignalList__sub_id_list(id_list)
        time_axis = self.time_axis()
        if not subplot or not HAVE_PYLAB:
            print(PYLAB_ERROR)
        else:
            xlabel = "Time (ms)"
            ylabel = "Conductance (nS)"
            set_labels(subplot, xlabel, ylabel)
            for signal_id in id_list:
                subplot.plot(time_axis, self.analog_signals[signal_id].signal, **kwargs)
                # hold() was removed from recent Matplotlib; call it only
                # when the plotting target still offers it.
                if hasattr(subplot, "hold"):
                    subplot.hold(1)
            pylab.draw()
Esempio n. 4
0
def show_results(result):
    """
    Visualizes the result of the parameter search as a single-row image
    (one cell per snr value) and saves it as a PNG file in the current
    working directory.

    Parameters:
    result - list of result dictionaries; every entry is read through the
             keys 'snr' and 'neuron_rate'.
    """
    import os
    import tempfile

    import numpy
    import NeuroTools.plotting as plotting

    snrs = numpy.sort([r['snr'] for r in result])
    neuron_rates = numpy.zeros(len(snrs))
    for snr_i, snr in enumerate(snrs):
        # First result recorded for this snr value.
        neuron_rates[snr_i] = [
            r['neuron_rate'] for r in result if r['snr'] == snr
        ][0]

    pylab = plotting.get_display(True)
    pylab.rcParams.update(plotting.pylab_params())
    # A single pre-formatted argument prints the same on Python 2 and 3.
    print("%s %s" % (snrs, neuron_rates))
    # imshow needs 2-D data: show the rates as one row indexed by snr.
    image = pylab.imshow(numpy.atleast_2d(neuron_rates),
                         interpolation='nearest',
                         origin='lower')
    # NOTE(review): the x axis indexes the sorted snr values; there is no
    # second parameter, so the y label is left empty.
    plotting.set_labels(image.axes, xlabel='snr', ylabel='')
    pylab.colorbar()
    # could add fancy xticks and yticks here
    (fd, figfilename) = tempfile.mkstemp(prefix='parameter_search_result',
                                         suffix='.png',
                                         dir=os.getcwd())
    # mkstemp hands back an open descriptor; only the unique name is needed.
    os.close(fd)
    pylab.gcf().savefig(figfilename)
Esempio n. 5
0
    def my_raster_plot(self, id_list=None, t_start=None, t_stop=None, display=True, kwargs=None, reduce=1, id_start=0):
        """
        Generate a raster plot for the SpikeList in a subwindow of interest,
        defined by id_list, t_start and t_stop. Cells are drawn on consecutive
        rows, numbered from id_start in the order given by id_list.

        Inputs:
            id_list  - can be a integer (and then N cells are randomly selected) or a list of ids. If None,
                       we use all the ids of the SpikeList
            t_start  - in ms. If not defined, the one of the SpikeList object is used
            t_stop   - in ms. If not defined, the one of the SpikeList object is used
            display  - passed to get_display() to obtain the plotting target
            kwargs   - dictionary containing extra parameters that will be sent to the plot
                       function (defaults to no extra parameters)
            reduce   - only every reduce-th spike is drawn
            id_start - row number assigned to the first cell of id_list

        Examples:
            >> z = subplot(221)
            >> spikelist.my_raster_plot(display=z, kwargs={'color':'r'})

        See also
            SpikeTrain.raster_plot
        """
        if kwargs is None:
            kwargs = {}
        subplot = get_display(display)
        # Identity test: `== None` is evaluated element-wise on numpy arrays.
        if id_list is None:
            id_list = self.id_list
            spk = self
        else:
            spk = self.id_slice(id_list)
            if isinstance(id_list, (int, numpy.integer)):
                # An integer only says how many cells to pick; the ids that
                # were picked are taken from the slice.
                # NOTE(review): assumes id_slice() returns an object exposing
                # `id_list` like self does - confirm.
                id_list = spk.id_list

        if t_start is None:
            t_start = spk.t_start
        if t_stop is None:
            t_stop = spk.t_stop
        if t_start != spk.t_start or t_stop != spk.t_stop:
            spk = spk.time_slice(t_start, t_stop)

        ids, spike_times = spk.convert(format="[ids, times]")

        # Map every original id onto its row number.
        n_cells = len(id_list)
        new_id_list = numpy.arange(n_cells) + id_start
        ids_map = dict(zip(id_list, new_id_list))

        new_ids = [ids_map[cell_id] for cell_id in ids]
        if len(spike_times) > 0:
            subplot.plot(spike_times[::reduce], new_ids[::reduce], ',', **kwargs)
        xlabel = "Time (ms)"
        ylabel = "Neuron #"
        set_labels(subplot, xlabel, ylabel)

        # Rows run from id_start to id_start + n_cells - 1; computing the
        # bounds directly also works when no cell is selected.
        min_id = id_start
        max_id = id_start + max(n_cells - 1, 0)
        length = t_stop - t_start
        set_axis_limits(subplot, t_start-0.05*length, t_stop+0.05*length, min_id-2, max_id+2)
        pylab.draw()
    def runTest(self):
        """Smoke test: label the pylab module and a SimpleMultiplot panel."""
        # Label a plain figure created through get_display.
        plotting.get_display(True)
        pylab.plot(range(10))
        plotting.set_labels(pylab, 'the x axis', 'the y axis')

        # Build a 1x1 SimpleMultiplot with arbitrary values and label its
        # only panel.
        self.nrows = 1
        self.ncolumns = 1
        self.smt = plotting.SimpleMultiplot(
            nrows=self.nrows,
            ncolumns=self.ncolumns,
            title='testMultiplot',
            xlabel='testXlabel',
            ylabel='testYlabel',
            scaling=('linear', 'log'),
        )
        first_panel = self.smt.panel(0)
        plotting.set_labels(first_panel, 'the x axis', 'the y axis')
Esempio n. 7
0
    def runTest(self):
        """Exercise plotting.set_labels on pylab and on a multiplot panel."""
        axis_labels = ('the x axis', 'the y axis')

        # Plain figure: plot ten points and label the pylab module itself.
        plotting.get_display(True)
        pylab.plot(range(10))
        plotting.set_labels(pylab, *axis_labels)

        # Single-panel SimpleMultiplot set up with arbitrary values.
        self.nrows = 1
        self.ncolumns = 1
        multiplot_options = dict(nrows=self.nrows,
                                 ncolumns=self.ncolumns,
                                 title='testMultiplot',
                                 xlabel='testXlabel',
                                 ylabel='testYlabel',
                                 scaling=('linear', 'log'))
        self.smt = plotting.SimpleMultiplot(**multiplot_options)
        plotting.set_labels(self.smt.panel(0), *axis_labels)
Esempio n. 8
0
    def plot(self, ylabel="Analog Signal", display=True, kwargs={}):
        """
        Draw this AnalogSignal against its time axis.

        Inputs:
            ylabel  - string used as the label of the y axis.
            display - if True, a new figure is created. Could also be a subplot
            kwargs  - dictionary of extra parameters forwarded to the plot
                      function

        Prints PYLAB_ERROR instead of plotting when no display is available.

        Examples:
            >> z = subplot(221)
            >> signal.plot(ylabel="Vm", display=z, kwargs={'color':'r'})
        """
        axes = get_display(display)
        times = self.time_axis()
        if axes and HAVE_PYLAB:
            set_labels(axes, "Time (ms)", ylabel)
            axes.plot(times, self.signal, **kwargs)
            pylab.draw()
        else:
            print(PYLAB_ERROR)
Esempio n. 9
0
    def plot(self, ylabel="Analog Signal", display=True, kwargs=None):
        """
        Plot the AnalogSignal

        Inputs:
            ylabel  - A string to specify the label on the yaxis.
            display - if True, a new figure is created. Could also be a subplot
            kwargs  - dictionary containing extra parameters that will be sent to the plot
                      function (defaults to no extra parameters)

        Examples:
            >> z = subplot(221)
            >> signal.plot(ylabel="Vm", display=z, kwargs={'color':'r'})
        """
        # A fresh dict per call instead of a shared mutable default.
        if kwargs is None:
            kwargs = {}
        subplot = get_display(display)
        time_axis = self.time_axis()
        if not subplot or not HAVE_PYLAB:
            # Single-argument print(...) behaves identically on Python 2 and 3.
            print(PYLAB_ERROR)
        else:
            xlabel = "Time (ms)"
            set_labels(subplot, xlabel, ylabel)
            subplot.plot(time_axis, self.signal, **kwargs)
            pylab.draw()
Esempio n. 10
0
    def plot(self, id_list=None, v_thresh=None, display=True, kwargs=None):
        """
        Plot all cells in the AnalogSignalList defined by id_list

        Inputs:
            id_list - can be a integer (and then N cells are randomly selected) or a
                      list of ids. If None, we use all the ids of the list
            v_thresh- For graphical purpose, plot a spike when Vm > V_thresh. If None,
                      just plot the raw Vm
            display - if True, a new figure is created. Could also be a subplot
            kwargs  - dictionary containing extra parameters that will be sent to the plot
                      function (defaults to no extra parameters)

        Examples:
            >> z = subplot(221)
            >> aslist.plot(5, v_thresh = -50, display=z, kwargs={'color':'r'})
        """
        # A fresh dict per call instead of a shared mutable default.
        if kwargs is None:
            kwargs = {}
        subplot = get_display(display)
        id_list = self._AnalogSignalList__sub_id_list(id_list)
        time_axis = self.time_axis()
        if not subplot or not HAVE_MATPLOTLIB:
            # Single-argument print(...) behaves identically on Python 2 and 3.
            print(MATPLOTLIB_ERROR)
        else:
            xlabel = "Time (ms)"
            ylabel = "Membrane Potential (mV)"
            set_labels(subplot, xlabel, ylabel)
            # hold() no longer exists on recent Matplotlib.
            can_hold = hasattr(subplot, "hold")
            for cell_id in id_list:
                to_be_plot = self.analog_signals[cell_id].signal
                if v_thresh is not None:
                    # Samples within 0.02 of the threshold (or above it) are
                    # drawn as a fixed-height spike at v_thresh + 0.5.
                    to_be_plot = pylab.where(to_be_plot >= v_thresh - 0.02,
                                             v_thresh + 0.5, to_be_plot)
                # Trim both to their common length for this trace only, so a
                # short trace does not shorten the time axis of later ones.
                n_samples = min(len(time_axis), len(to_be_plot))
                subplot.plot(time_axis[:n_samples], to_be_plot[:n_samples],
                             **kwargs)
                if can_hold:
                    subplot.hold(1)
Esempio n. 11
0
    def plot(self, id_list=None, v_thresh=None, display=True, kwargs=None):
        """
        Plot all cells in the AnalogSignalList defined by id_list

        Inputs:
            id_list - can be a integer (and then N cells are randomly selected) or a
                      list of ids. If None, we use all the ids of the list
            v_thresh- For graphical purpose, plot a spike when Vm > V_thresh. If None,
                      just plot the raw Vm
            display - if True, a new figure is created. Could also be a subplot
            kwargs  - dictionary containing extra parameters that will be sent to the plot
                      function (defaults to no extra parameters)

        Examples:
            >> z = subplot(221)
            >> aslist.plot(5, v_thresh = -50, display=z, kwargs={'color':'r'})
        """
        # Avoid a mutable default argument shared between calls.
        kwargs = {} if kwargs is None else kwargs
        subplot   = get_display(display)
        id_list   = self._AnalogSignalList__sub_id_list(id_list)
        time_axis = self.time_axis()
        if not subplot or not HAVE_MATPLOTLIB:
            print(MATPLOTLIB_ERROR)
        else:
            xlabel = "Time (ms)"
            ylabel = "Membrane Potential (mV)"
            set_labels(subplot, xlabel, ylabel)
            for signal_id in id_list:
                to_be_plot = self.analog_signals[signal_id].signal
                if v_thresh is not None:
                    # Samples within 0.02 of the threshold (or above it) are
                    # drawn as a fixed-height spike at v_thresh + 0.5.
                    to_be_plot = pylab.where(to_be_plot>=v_thresh-0.02, v_thresh+0.5, to_be_plot)
                # Per-trace trimming to the common length: the shared time
                # axis is never shortened for the following traces.
                common = min(len(time_axis), len(to_be_plot))
                subplot.plot(time_axis[:common], to_be_plot[:common], **kwargs)
                # hold() was removed from recent Matplotlib; call it only
                # when the plotting target still offers it.
                if hasattr(subplot, "hold"):
                    subplot.hold(1)
Esempio n. 12
0
def crosscorrelate(sua1, sua2, lag=None, n_pred=1, predictor=None,
                   display=False, kwargs=None):
    """Cross-correlation between two series of discrete events (e.g. spikes).

    Calculates the cross-correlation between
    two vectors containing event times.
    Returns ``(differences, pred, norm)``. See below for details.

    Adapted from original script written by Martin P. Nawrot for the
    FIND MATLAB toolbox [1]_.

    Parameters
    ----------
    sua1, sua2 : 1D row or column array-like or `SpikeTrain`
        Event times. If sua2 == sua1, the result is the autocorrelogram.
        Any object with a ``spike_times`` attribute is treated as a
        SpikeTrain; everything else is converted with ``np.asarray``.
    lag : float
        Lag for which relative event timing is considered
        with a max difference of +/- lag. If None, a default lag is computed
        from the mean inter-event interval of the input with fewer events
        (10x that interval if it is `sua1`, 20x if it is `sua2`).
    n_pred : int
        Number of surrogate compilations for the predictor. This
        influences the total length of the predictor output array
    predictor : {None, 'shuffle'}
        Determines the type of bootstrap predictor to be used.
        'shuffle' shuffles interevent intervals of the longer input array
        and calculates relative differences with the shorter input array.
        `n_pred` determines the number of repeated shufflings, resulting
        differences are pooled from all repeated shufflings.
    display : boolean
        If True the corresponding plots will be displayed and nothing is
        returned (unless no plotting backend is available, in which case
        the values are returned). If False,
        differences, pred and norm will be returned.
    kwargs : dict
        Arguments to be passed to np.histogram.

    Returns
    -------
    differences : np array
        Accumulated differences of events in `sua1` minus the events in
        `sua2`. Thus positive values relate to events of `sua2` that
        lead events of `sua1`. Units are the same as the input arrays.
    pred : np array
        Accumulated differences based on the prediction method (empty when
        `predictor` is None). Units are the same as the input arrays.
    norm : float
        Normalization factor used to scale the bin heights in `differences` and
        `pred`. ``differences/norm`` and ``pred/norm`` correspond to the linear
        correlation coefficient.

    Raises
    ------
    AssertionError
        If `predictor` is neither None nor 'shuffle'.
    TypeError
        If an input is neither a SpikeTrain nor a row/column vector.

    Examples
    --------
    >> crosscorrelate(np_array1, np_array2)
    >> crosscorrelate(spike_train1, spike_train2)
    >> crosscorrelate(spike_train1, spike_train2, lag = 150.0)
    >> crosscorrelate(spike_train1, spike_train2, display=True,
                      kwargs={'bins':100})

    See also
    --------
    ccf

    .. [1] Meier R, Egert U, Aertsen A, Nawrot MP, "FIND - a unified framework
       for neural data analysis"; Neural Netw. 2008 Oct; 21(8):1085-93.

    """
    # Compare by value: identity (`is`) only works for interned strings.
    assert predictor == 'shuffle' or predictor is None, (
        "predictor must be either None or 'shuffle'. "
        "Other predictors are not yet implemented.")
    if kwargs is None:
        kwargs = {}

    # Check whether sua1 and sua2 are SpikeTrains or arrays
    sua = []
    for x in (sua1, sua2):
        if hasattr(x, 'spike_times'):
            sua.append(np.asarray(x.spike_times))
            continue
        x = np.asarray(x)
        if x.ndim == 1:
            sua.append(x)
        elif x.ndim == 2 and (x.shape[0] == 1 or x.shape[1] == 1):
            sua.append(x.ravel())
        else:
            raise TypeError("sua1 and sua2 must be either instances of the "
                            "SpikeTrain class or column/row vectors")
    sua1, sua2 = sua

    # From here on sua1 is the input with fewer events; `reverse` records
    # whether the inputs were swapped so the sign can be restored later.
    if sua1.size < sua2.size:
        if lag is None:
            lag = np.ceil(10*np.mean(np.diff(sua1)))
        reverse = False
    else:
        if lag is None:
            lag = np.ceil(20*np.mean(np.diff(sua2)))
        sua1, sua2 = sua2, sua1
        reverse = True

    def pooled_differences(events):
        # Differences between every sua1 event and the `events` lying
        # strictly within +/- lag of it, pooled into one float array.
        per_event = [t - events[(events > t - lag) & (events < t + lag)]
                     for t in sua1]
        return np.concatenate([np.empty(0)] + per_event)

    # construct predictor
    if predictor == 'shuffle':
        isi = np.diff(sua2)
        surrogates = []
        for _ in range(n_pred):
            idx = np.random.permutation(isi.size-1)
            # Rebuild a train from shuffled intervals, starting at the first
            # event plus a random exponential offset.
            surrogates.append(
                np.insert(np.cumsum(isi[idx]), 0, 0)
                + (sua2.min() + np.random.exponential(isi.mean())))
        sua2_ = np.concatenate([np.empty(0)] + surrogates)

    # calculate cross differences in spike times
    differences = pooled_differences(sua2)
    if predictor == 'shuffle':
        pred = pooled_differences(sua2_)
    else:
        pred = np.array([])
    if reverse:
        differences = -differences
        pred = -pred

    norm = np.sqrt(sua1.size * sua2.size)

    if not display:
        return differences, pred, norm

    # Plot the results if display=True
    subplot = get_display(display)
    if not subplot or not HAVE_PYLAB:
        return differences, pred, norm

    try:
        # Plot the cross-correlation on the bin centers
        counts, bin_edges = np.histogram(differences, **kwargs)
        bin_centers = bin_edges[1:] - np.diff(bin_edges)/2
        counts = counts / norm
        subplot.plot(bin_centers, counts, label='cross-correlation', color='b')
        set_labels(subplot, "Time", "Cross-correlation coefficient")
        if predictor == 'shuffle':
            # Plot the predictor, also on bin centers so both curves share
            # the same x convention.
            norm_ = norm * n_pred
            counts_, bin_edges_ = np.histogram(pred, **kwargs)
            counts_ = counts_ / norm_
            bin_centers_ = bin_edges_[1:] - np.diff(bin_edges_)/2
            subplot.plot(bin_centers_, counts_, label='predictor')
            subplot.legend()
        pylab.draw()
    except ValueError:
        print("There are no correlated events within the selected lag"
              " window of %s" % lag)
Esempio n. 13
0
    def event_triggered_average(self, events, average = True, t_min = 0, t_max = 100, display = False, with_time = False, kwargs=None):
        """
        Return the spike triggered averaged of an analog signal according to selected events,
        on a time window t_spikes - tmin, t_spikes + tmax
        Can return either the averaged waveform (average = True), or an array of all the
        waveforms triggered by all the spikes.

        Events whose window does not fit entirely between t_start and t_stop
        of the signal are skipped. If no event fits, the average is NaN.
        When a display is available the result is plotted and nothing is
        returned.

        Inputs:
            events    - Can be a SpikeTrain object (and events will be the spikes) or just a list
                        of times
            average   - If True, return a single vector of the averaged waveform. If False,
                        return an array of all the waveforms.
            t_min     - Time (>=0) to average the signal before an event, in ms (default 0)
            t_max     - Time (>=0) to average the signal after an event, in ms  (default 100)
            display   - if True, a new figure is created. Could also be a subplot.
            with_time - if True (and nothing is plotted), the time axis is returned
                        together with the result
            kwargs    - dictionary containing extra parameters that will be sent to the plot
                        function (defaults to no extra parameters)

        Examples:
            >> vm.event_triggered_average(spktrain, average=False, t_min = 50, t_max = 150)
            >> vm.event_triggered_average(spktrain, average=True)
            >> vm.event_triggered_average(range(0,1000,10), average=False, display=True)
        """
        if kwargs is None:
            kwargs = {}

        if isinstance(events, SpikeTrain):
            events = events.spike_times
            ylabel = "Spike Triggered Average"
        else:
            assert numpy.iterable(events), "events should be a SpikeTrain object or an iterable object"
            ylabel = "Event Triggered Average"
        assert (t_min >= 0) and (t_max >= 0), "t_min and t_max should be greater than 0"
        assert len(events) > 0, "events should not be empty and should contained at least one element"

        # recalculate everything into integer timesteps, is more stable against
        # rounding errors and subsequent cutouts with different sizes; integers
        # are also required for slicing and for the number of linspace samples
        events  = numpy.floor(numpy.array(events)/self.dt).astype(int)
        t_min_l = int(numpy.floor(t_min/self.dt))
        t_max_l = int(numpy.floor(t_max/self.dt))
        t_start = int(numpy.floor(self.t_start/self.dt))
        t_stop  = int(numpy.floor(self.t_stop/self.dt))

        # The time axis has exactly as many samples as every cut-out window.
        N         = t_min_l + t_max_l
        time_axis = numpy.linspace(-t_min, t_max, N)
        Nspikes   = 0.
        subplot   = get_display(display)
        if average:
            result = numpy.zeros(N, float)
        else:
            result = []

        for spike in events:
            if ((spike-t_min_l) >= t_start) and ((spike+t_max_l) < t_stop):
                # index relative to the first sample of the signal
                spike = spike - t_start
                window = self.signal[(spike-t_min_l):(spike+t_max_l)]
                if average:
                    result += window
                else:
                    result.append(window)
                Nspikes += 1
        if average:
            result = result/Nspikes
        else:
            result = numpy.array(result)

        if not subplot or not HAVE_PYLAB:
            if with_time:
                return result, time_axis
            else:
                return result
        else:
            xlabel = "Time (ms)"
            set_labels(subplot, xlabel, ylabel)
            if average:
                subplot.plot(time_axis, result, **kwargs)
            else:
                # hold() no longer exists on recent Matplotlib.
                can_hold = hasattr(subplot, "hold")
                for idx in range(len(result)):
                    subplot.plot(time_axis, result[idx,:], c='0.5', **kwargs)
                    if can_hold:
                        subplot.hold(1)
                result = numpy.sum(result, axis=0)/Nspikes
                subplot.plot(time_axis, result, c='k', **kwargs)
            xmin, xmax, ymin, ymax = subplot.axis()

            # vertical marker at the event time
            subplot.plot([0,0],[ymin, ymax], c='r')
            set_axis_limits(subplot, -t_min, t_max, ymin, ymax)
            pylab.draw()
Esempio n. 14
0
    def event_triggered_average(self,
                                events,
                                average=True,
                                t_min=0,
                                t_max=100,
                                display=False,
                                with_time=False,
                                kwargs=None):
        """
        Return the spike triggered averaged of an analog signal according to selected events,
        on a time window t_spikes - tmin, t_spikes + tmax
        Can return either the averaged waveform (average = True), or an array of all the
        waveforms triggered by all the spikes.

        Events whose window does not fit entirely between t_start and t_stop
        of the signal are skipped. If no event fits, the average is NaN.
        When a display is available the result is plotted and nothing is
        returned.

        Inputs:
            events    - Can be a SpikeTrain object (and events will be the spikes) or just a list
                        of times
            average   - If True, return a single vector of the averaged waveform. If False,
                        return an array of all the waveforms.
            t_min     - Time (>=0) to average the signal before an event, in ms (default 0)
            t_max     - Time (>=0) to average the signal after an event, in ms  (default 100)
            display   - if True, a new figure is created. Could also be a subplot.
            with_time - if True (and nothing is plotted), the time axis is returned
                        together with the result
            kwargs    - dictionary containing extra parameters that will be sent to the plot
                        function (defaults to no extra parameters)

        Examples:
            >> vm.event_triggered_average(spktrain, average=False, t_min = 50, t_max = 150)
            >> vm.event_triggered_average(spktrain, average=True)
            >> vm.event_triggered_average(range(0,1000,10), average=False, display=True)
        """
        # Avoid a mutable default argument shared between calls.
        kwargs = {} if kwargs is None else kwargs

        if isinstance(events, SpikeTrain):
            events = events.spike_times
            ylabel = "Spike Triggered Average"
        else:
            assert numpy.iterable(
                events
            ), "events should be a SpikeTrain object or an iterable object"
            ylabel = "Event Triggered Average"
        assert (t_min >= 0) and (t_max >=
                                 0), "t_min and t_max should be greater than 0"
        assert len(
            events
        ) > 0, "events should not be empty and should contained at least one element"

        # Work in whole timesteps: this is more stable against rounding
        # errors, keeps all cut-outs the same size, and gives the integers
        # that slicing and numpy.linspace require.
        dt = self.dt
        events = numpy.floor(numpy.array(events) / dt).astype(int)
        t_min_l = int(numpy.floor(t_min / dt))
        t_max_l = int(numpy.floor(t_max / dt))
        t_start = int(numpy.floor(self.t_start / dt))
        t_stop = int(numpy.floor(self.t_stop / dt))

        # One time-axis sample per sample of a cut-out window.
        n_steps = t_min_l + t_max_l
        time_axis = numpy.linspace(-t_min, t_max, n_steps)
        Nspikes = 0.
        subplot = get_display(display)
        if average:
            result = numpy.zeros(n_steps, float)
        else:
            result = []

        for spike in events:
            if ((spike - t_min_l) >= t_start) and ((spike + t_max_l) < t_stop):
                # index relative to the first sample of the signal
                spike = spike - t_start
                cutout = self.signal[(spike - t_min_l):(spike + t_max_l)]
                if average:
                    result += cutout
                else:
                    result.append(cutout)
                Nspikes += 1
        if average:
            result = result / Nspikes
        else:
            result = numpy.array(result)

        if not subplot or not HAVE_PYLAB:
            if with_time:
                return result, time_axis
            else:
                return result
        else:
            xlabel = "Time (ms)"
            set_labels(subplot, xlabel, ylabel)
            if average:
                subplot.plot(time_axis, result, **kwargs)
            else:
                for idx in range(len(result)):
                    subplot.plot(time_axis, result[idx, :], c='0.5', **kwargs)
                    # hold() was removed from recent Matplotlib; call it
                    # only when the plotting target still offers it.
                    if hasattr(subplot, "hold"):
                        subplot.hold(1)
                result = numpy.sum(result, axis=0) / Nspikes
                subplot.plot(time_axis, result, c='k', **kwargs)
            xmin, xmax, ymin, ymax = subplot.axis()

            # vertical marker at the event time
            subplot.plot([0, 0], [ymin, ymax], c='r')
            set_axis_limits(subplot, -t_min, t_max, ymin, ymax)
            pylab.draw()
Esempio n. 15
0
def crosscorrelate(sua1,
                   sua2,
                   lag=None,
                   n_pred=1,
                   predictor=None,
                   display=False,
                   kwargs=None):
    """Cross-correlation between two series of discrete events (e.g. spikes).

    Calculates the cross-correlation between
    two vectors containing event times.
    Returns ``(differences, pred, norm)``. See below for details.

    Adapted from original script written by Martin P. Nawrot for the
    FIND MATLAB toolbox [1]_.

    Parameters
    ----------
    sua1, sua2 : 1D row or column array-like or `SpikeTrain`
        Event times. If sua2 == sua1, the result is the autocorrelogram.
        Any object with a ``spike_times`` attribute is treated as a
        SpikeTrain; everything else is converted with ``np.asarray``.
    lag : float
        Lag for which relative event timing is considered
        with a max difference of +/- lag. If None, a default lag is computed
        from the mean inter-event interval of the input with fewer events
        (10x that interval if it is `sua1`, 20x if it is `sua2`).
    n_pred : int
        Number of surrogate compilations for the predictor. This
        influences the total length of the predictor output array
    predictor : {None, 'shuffle'}
        Determines the type of bootstrap predictor to be used.
        'shuffle' shuffles interevent intervals of the longer input array
        and calculates relative differences with the shorter input array.
        `n_pred` determines the number of repeated shufflings, resulting
        differences are pooled from all repeated shufflings.
    display : boolean
        If True the corresponding plots will be displayed and nothing is
        returned (unless no plotting backend is available, in which case
        the values are returned). If False,
        differences, pred and norm will be returned.
    kwargs : dict
        Arguments to be passed to np.histogram.

    Returns
    -------
    differences : np array
        Accumulated differences of events in `sua1` minus the events in
        `sua2`. Thus positive values relate to events of `sua2` that
        lead events of `sua1`. Units are the same as the input arrays.
    pred : np array
        Accumulated differences based on the prediction method (empty when
        `predictor` is None). Units are the same as the input arrays.
    norm : float
        Normalization factor used to scale the bin heights in `differences` and
        `pred`. ``differences/norm`` and ``pred/norm`` correspond to the linear
        correlation coefficient.

    Raises
    ------
    AssertionError
        If `predictor` is neither None nor 'shuffle'.
    TypeError
        If an input is neither a SpikeTrain nor a row/column vector.

    Examples
    --------
    >> crosscorrelate(np_array1, np_array2)
    >> crosscorrelate(spike_train1, spike_train2)
    >> crosscorrelate(spike_train1, spike_train2, lag = 150.0)
    >> crosscorrelate(spike_train1, spike_train2, display=True,
                      kwargs={'bins':100})

    See also
    --------
    ccf

    .. [1] Meier R, Egert U, Aertsen A, Nawrot MP, "FIND - a unified framework
       for neural data analysis"; Neural Netw. 2008 Oct; 21(8):1085-93.

    """
    # Value comparison: `is` only matches strings that happen to be interned.
    use_shuffle = predictor == 'shuffle'
    assert use_shuffle or predictor is None, (
        "predictor must be either None or 'shuffle'. "
        "Other predictors are not yet implemented.")
    kwargs = {} if kwargs is None else kwargs

    #Check whether sua1 and sua2 are SpikeTrains or arrays
    sua = []
    for x in (sua1, sua2):
        if hasattr(x, 'spike_times'):
            sua.append(np.asarray(x.spike_times))
            continue
        x = np.asarray(x)
        if x.ndim == 1:
            sua.append(x)
        elif x.ndim == 2 and (x.shape[0] == 1 or x.shape[1] == 1):
            sua.append(x.ravel())
        else:
            raise TypeError("sua1 and sua2 must be either instances of the "
                            "SpikeTrain class or column/row vectors")
    sua1 = sua[0]
    sua2 = sua[1]

    # After this branch sua1 is the input with fewer events; `reverse`
    # remembers a swap so the sign of the result can be restored.
    if sua1.size < sua2.size:
        if lag is None:
            lag = np.ceil(10 * np.mean(np.diff(sua1)))
        reverse = False
    else:
        if lag is None:
            lag = np.ceil(20 * np.mean(np.diff(sua2)))
        sua1, sua2 = sua2, sua1
        reverse = True

    def pooled_differences(events):
        # For each sua1 event, its distance to every event of `events`
        # strictly within +/- lag; all pooled into a single float array.
        chunks = [np.empty(0)]
        for t in sua1:
            nearby = (events > t - lag) & (events < t + lag)
            chunks.append(t - events[nearby])
        return np.concatenate(chunks)

    #construct predictor
    if use_shuffle:
        isi = np.diff(sua2)
        surrogate_trains = [np.empty(0)]
        for _ in range(n_pred):
            idx = np.random.permutation(isi.size - 1)
            # Train rebuilt from shuffled intervals, starting at the first
            # event plus a random exponential offset.
            onset = sua2.min() + np.random.exponential(isi.mean())
            surrogate_trains.append(
                np.insert(np.cumsum(isi[idx]), 0, 0) + onset)
        sua2_ = np.concatenate(surrogate_trains)

    #calculate cross differences in spike times
    differences = pooled_differences(sua2)
    pred = pooled_differences(sua2_) if use_shuffle else np.array([])
    if reverse:
        differences = -differences
        pred = -pred

    norm = np.sqrt(sua1.size * sua2.size)

    # Plot the results if display=True
    if display:
        subplot = get_display(display)
        if not subplot or not HAVE_PYLAB:
            return differences, pred, norm
        else:
            try:
                # Plot the cross-correlation on the bin centers
                counts, bin_edges = np.histogram(differences, **kwargs)
                edge_distances = np.diff(bin_edges)
                bin_centers = bin_edges[1:] - edge_distances / 2
                counts = counts / norm
                xlabel = "Time"
                ylabel = "Cross-correlation coefficient"
                subplot.plot(bin_centers,
                             counts,
                             label='cross-correlation',
                             color='b')
                set_labels(subplot, xlabel, ylabel)
                if use_shuffle:
                    # Plot the predictor on bin centers as well, so both
                    # curves share the same x convention.
                    norm_ = norm * n_pred
                    counts_, bin_edges_ = np.histogram(pred, **kwargs)
                    counts_ = counts_ / norm_
                    bin_centers_ = bin_edges_[1:] - np.diff(bin_edges_) / 2
                    subplot.plot(bin_centers_, counts_, label='predictor')
                    subplot.legend()
                pylab.draw()
            except ValueError:
                print("There are no correlated events within the selected lag"
                      " window of %s" % lag)
    else:
        return differences, pred, norm