Exemplo n.º 1
0
def plot_trig_(trig_data, window=None, ax=None, error_fn='var', **kwargs):
    '''
    Standard plotting methods for "triggered" data created by trig_ function above.

    Plots the mean of `trig_data` taken along axis 2 (then flattened), and
    optionally an error band computed by `error_fn` along the same axis.

    Parameters
    ----------
    trig_data : array_like
        Triggered data; averaged over axis 2. The result is raveled, so any
        remaining axes are flattened into a single trace.
    window : array_like or None
        x-coordinates of the trace. Defaults to np.arange(len(mean)).
    ax : matplotlib Axes or None
        Axes to draw on. If None, a new figure with one subplot is created
        through plotutil.
    error_fn : callable, 'var', 'std', 'sem' or None
        Spread statistic computed along axis 2 ('var' -> np.var,
        'std' -> np.std, 'sem' -> scipy.stats.sem). A callable must accept
        an `axis` keyword. None plots the mean only.
    **kwargs
        Forwarded to plotutil.error_line (with an error band) or ax.plot.

    Returns
    -------
    The Axes that was drawn on.

    Raises
    ------
    ValueError
        If `error_fn` is not a callable, None, or one of the names above.
    '''
    if error_fn is None or callable(error_fn):
        pass
    elif error_fn == 'var':
        error_fn = np.var
    elif error_fn == 'std':
        error_fn = np.std
    elif error_fn == 'sem':
        error_fn = stats.sem
    else:
        raise ValueError('Unrecognized error_fn: %r' % error_fn)

    if ax is None:
        import plotutil
        plt.figure()
        axes = plotutil.subplots(1, 1, hold=True)
        ax = axes[0, 0]

    mean_ = np.mean(trig_data, axis=2).ravel()

    # Must be `is None`: `array == None` compares elementwise, and using the
    # resulting array in an `if` raises "truth value ... is ambiguous".
    if window is None:
        window = np.arange(len(mean_))

    if error_fn is None:
        ax.plot(window, mean_, **kwargs)
    else:
        import plotutil
        err_ = error_fn(trig_data, axis=2).ravel()
        plotutil.error_line(ax, window, mean_, err_, **kwargs)
    return ax
 def plot_before_and_after_C(self):
     '''
     Draw the decoder's C matrix before (left) and after (right) this block.

     "Before" is self.decoder; "after" is whatever
     dbfn.get_decoders_trained_in_block(self.id) returns. Each is drawn with
     its own plot_C method on one of two side-by-side axes in a new figure.
     NOTE(review): the helper's plural name suggests it may return several
     decoders; this code assumes a single object with plot_C — confirm.
     '''
     decoder_old = self.decoder
     decoder_new = dbfn.get_decoders_trained_in_block(self.id)

     plt.figure()
     ax_old, ax_new = plotutil.subplots(1, 2, return_flat=True, hold=True)

     for decoder, target_ax in ((decoder_old, ax_old), (decoder_new, ax_new)):
         decoder.plot_C(ax=target_ax)
    def plot_C_hist(self,
            param_fns=(lambda C_hist: C_hist[:, :, 3],
                       lambda C_hist: C_hist[:, :, 5],
                       lambda C_hist: C_hist[:, :, 6],
                       lambda C_hist: np.sqrt(C_hist[:, :, 3]**2 + C_hist[:, :, 5]**2)),
            labels=('Change in x-vel tuning', 'Change in z-vel tuning',
                    'Change in baseline', 'Change in mod. depth'),
            units_per_row=7):
        '''
        Plot parameter trajectories for C

        Reads the 'filt_C' history from self.hdf.root.task, skipping row 0. For
        each function in `param_fns`, it plots the per-unit change relative to
        the first retained sample. Units are split into rows of `units_per_row`
        units each, and each column corresponds to one entry of `param_fns`.
        All panels in a column share the same y-limits. The figure is saved as
        'clda_param_hist'.

        Parameters
        ----------
        param_fns : sequence of callables
            Each maps the C history array to a (time, unit) array. The defaults
            pick columns 3, 5 and 6 of the last axis and the norm of columns 3
            and 5 (described by the default labels as x-vel tuning, z-vel
            tuning, baseline and modulation depth).
        labels : sequence of str
            x-axis label of each column; at least as long as `param_fns`.
        units_per_row : int
            Number of units drawn per row of subplots (default 7).
        '''
        C_hist = self.hdf.root.task[1:]['filt_C']
        n_units = C_hist.shape[1]
        n_blocks = int(np.ceil(float(n_units) / units_per_row))

        plt.figure(facecolor='w', figsize=(8. / 3 * len(param_fns), 2 * n_blocks))
        axes = plotutil.subplots(n_blocks, len(param_fns), y=0.01)

        for m, fn in enumerate(param_fns):
            # Evaluate the parameter function once per column, not once per row.
            param_hist_all = fn(C_hist)
            for k in range(n_blocks):
                start = k * units_per_row
                # Clamp so the last row's label names only units that exist.
                stop = min(start + units_per_row, n_units)
                param_hist = param_hist_all[:, start:stop]
                param_hist_diff = param_hist - param_hist[0, :]
                axes[k, m].plot(param_hist_diff)
                axes[k, m].set_xticklabels([])
                if m == 0:
                    plotutil.ylabel(axes[k, m], 'Units %d-%d' % (start, stop - 1))
                if k == n_blocks - 1:
                    plotutil.xlabel(axes[k, m], labels[m])

            # Share y-limits across every row of this column. A list (not
            # `map`) keeps np.vstack working on Python 3 / recent NumPy.
            lims = np.vstack([col_ax.get_ylim() for col_ax in axes[:, m]])
            ylim = min(lims[:, 0]), max(lims[:, 1])
            plotutil.set_axlim(axes[:, m], ylim, axis='y')

        self.save_plot('clda_param_hist')
 def plot_loop_times(self, intended_update_rate=60.):
     '''
     Histogram the 'loop_time' values stored in the task HDF table.

     Bins run from 0 to 0.05 in steps of 0.0005. A dashed black vertical line
     marks 1/intended_update_rate, the intended loop period (presumably in
     seconds, matching loop_time — TODO confirm). Saved as 'loop_times'.
     '''
     durations = self.hdf.root.task[:]['loop_time'].ravel()
     plt.figure()
     hist_ax = plotutil.subplots(1, 1, return_flat=True)[0]
     bin_edges = np.arange(0, 0.050, 0.0005)
     plotutil.histogram_line(hist_ax, durations, bin_edges)
     target_period = 1. / intended_update_rate
     hist_ax.axvline(target_period, color='black', linestyle='--')
     self.save_plot('loop_times')
 def plot_C_hist_pds(self):
     '''
     Plot snapshots of the C history with self.decoder.plot_pds, one per row.

     Snapshots are rows of the 'filt_C' history from index 1 up to (not
     including) 10000, taken every sec_per_min * self.update_rate rows.
     Each snapshot goes in its own square (aspect=1) subplot. The figure is
     saved as 'clda_param_hist_pds'.
     '''
     stride = sec_per_min * self.update_rate
     snapshots = self.hdf.root.task[1:10000:stride]['filt_C']
     n_snapshots = snapshots.shape[0]

     plt.figure(figsize=(3, 3 * n_snapshots))
     panel_axes = plotutil.subplots(n_snapshots, 1, return_flat=True, hold=True, aspect=1)
     for idx, snapshot in enumerate(snapshots):
         self.decoder.plot_pds(snapshot, ax=panel_axes[idx])
     self.save_plot('clda_param_hist_pds')
    def plot_rewards_per_min(self, ax=None, show=False, max_ylim=None, save=True, **kwargs):
        '''
        Make a plot of the rewards per minute

        Plots the reward-rate trace from self.get_rewards_per_min, keeping
        every 900th sample. Vertical lines mark when CLDA stops
        (self.clda_stop_time). When it can be determined, another line marks
        the first sample where 'assist_level' is 0.

        Parameters
        ----------
        ax : matplotlib Axes or None
            Axes to draw into. If None, a new 4x3 figure is created. If an
            Axes is supplied, saving is disabled regardless of `save`.
        show : bool
            Call plt.show() at the end.
        max_ylim : number or None
            When None, it is computed (at least 15). NOTE(review): it is
            currently not applied to the axes, because the y-limit calls
            below are commented out.
        save : bool
            Save the figure as 'rewards_per_min' (only when `ax` is None).
        **kwargs
            Forwarded to self.get_rewards_per_min.
        '''
        import plotutil
        tvec, rewards_per_min = self.get_rewards_per_min(**kwargs)
        rewards_per_min = rewards_per_min[::900]
        tvec = tvec[::900]

        # find the time when CLDA turns off
        clda_stop = self.clda_stop_time

        if ax is None:
            plt.figure(figsize=(4, 3))
            axes = plotutil.subplots(1, 1, return_flat=True, hold=True, left_offset=0.1)
            ax = axes[0]
        else:
            # The caller owns the figure, so don't save it from here.
            save = False

        # Best-effort marker for when the assist turns off. Tables without an
        # 'assist_level' field (KeyError/ValueError) and blocks where the
        # assist never reaches 0 (IndexError) simply get no marker. Any other
        # error is a real bug and is allowed to propagate.
        try:
            assist_level = self.hdf.root.task[:]['assist_level'].ravel()
            assist_stop = np.nonzero(assist_level == 0)[0][0]
        except (KeyError, ValueError, IndexError):
            pass
        else:
            assist_stop *= min_per_sec * 1. / self.update_rate  # convert to min
            ax.axvline(assist_stop, label='Assist off', color='green', linewidth=2)

        ax.axvline(clda_stop, label='CLDA off', color='blue', linewidth=2, linestyle='--')
        ax.plot(tvec * min_per_sec, rewards_per_min, color='black', linewidth=2)
        if max_ylim is None:
            max_ylim = int(max(15, int(np.ceil(max(rewards_per_min)))))
        max_xlim = int(np.ceil(max(tvec * min_per_sec)))
        # plotutil.set_axlim(ax, [0, max_ylim], labels=range(max_ylim+1), axis='y')
        # plotutil.set_axlim(ax, [0, max_ylim], labels=range(0, max_ylim+1), axis='y')
        plotutil.set_xlim(ax, [0, max_xlim])
        plotutil.ylabel(ax, 'Rewards/min', offset=-0.08)
        plotutil.xlabel(ax, 'Time during block (min)')

        plotutil.legend(ax)
        ax.grid()

        if save:
            self.save_plot('rewards_per_min')

        if show:
            plt.show()
    def plot_wfs(self, chan, ax_sorted=None, ax_unsorted=None):
        '''
        Plot every recorded spike waveform on channel `chan`.

        Waveforms of units 1-4 are overlaid on `ax_sorted`, one color per
        unit. Unit-0 waveforms go on `ax_unsorted`. Afterwards both axes get
        the same y-limits. If either axes is None, a new figure with two
        side-by-side subplots is created and both axes are replaced.

        Reads self.spike_timestamps (its 'chan' and 'unit' fields select the
        spikes) and self.spike_waveforms (indexed by the same spike indices).
        '''
        if ax_unsorted is None or ax_sorted is None:
            plt.figure()
            ax_sorted, ax_unsorted = plotutil.subplots(1, 2, hold=True, return_flat=True)

        colors = ['red', 'blue', 'green', 'black', 'magenta']
        on_chan = self.spike_timestamps['chan'] == chan
        unit_ids = self.spike_timestamps['unit']

        # Sorted units 1-4 first, then unsorted unit 0 (same order as before).
        targets = [(1, ax_sorted), (2, ax_sorted), (3, ax_sorted), (4, ax_sorted),
                   (0, ax_unsorted)]
        for unit, target_ax in targets:
            inds, = np.nonzero(on_chan * (unit_ids == unit))
            if len(inds) > 0:
                xlist, ylist = self._waveform_polyline(self.spike_waveforms[inds])
                target_ax.plot(xlist, ylist, color=colors[unit])

        ## Set y limits of plot
        sorted_ylim = ax_sorted.get_ylim()
        unsorted_ylim = ax_unsorted.get_ylim()
        ylim = (min(sorted_ylim[0], unsorted_ylim[0]), max(sorted_ylim[1], unsorted_ylim[1]))
        ax_unsorted.set_ylim(ylim)
        ax_sorted.set_ylim(ylim)

    @staticmethod
    def _waveform_polyline(waveforms):
        '''Join waveforms into one x/y polyline, with None entries as line breaks.'''
        # A single plot() call with None separators draws all waveforms as one
        # artist instead of one line object per spike.
        xlist = []
        ylist = []
        for wf in waveforms:
            xlist.extend(np.arange(len(wf)))
            xlist.append(None)
            ylist.extend(wf)
            ylist.append(None)
        return xlist, ylist
Exemplo n.º 8
0
def plot_pair(perf, labels=('KF', 'PPF')):
    '''
    Compare two decoders on paired performance metrics, one subplot per metric.

    For each metric, the two conditions' values are scattered at x=1 and
    x=2, and paired samples are joined by black lines. A paired t-test
    (ttest_rel) p-value and the sample count go in the x-axis label. The
    p-value is one-sided, in the direction given by `signs`. The t statistic
    is for perf[key][0] - perf[key][1], so a sign of -1 means the second
    condition is expected to be larger.

    Parameters
    ----------
    perf : mapping
        perf[key] is a pair (values_0, values_1) of equal-length sequences for
        each key in 'trials_per_min', 'ME', 'reach_time', 'perc_correct'.
    labels : sequence of two str
        Tick labels for the two conditions.
    '''
    plt.figure(figsize=(4, 8))
    axes = plotutil.subplots(4, 1, hold=True)

    # plot KF, PPFLC
    keys = ['trials_per_min', 'ME', 'reach_time', 'perc_correct']
    signs = [-1, +1, +1, -1]
    for k, (key, sign) in enumerate(zip(keys, signs)):
        ax = axes[k, 0]
        vals_0, vals_1 = perf[key][0], perf[key][1]
        ax.scatter(1 * np.ones(len(vals_0)), vals_0)
        ax.scatter(2 * np.ones(len(vals_1)), vals_1)
        for a, b in zip(vals_0, vals_1):
            ax.plot([1., 2.], [a, b], color='black')
        ax.set_xticks([1., 2.])
        ax.set_xticklabels(labels)
        plotutil.set_xlim(ax, axis='y')
        plotutil.ylabel(ax, key)
        h, p_val = ttest_rel(vals_0, vals_1)
        print('%s %s' % (h, p_val))
        # Convert the two-sided p-value to a one-sided one. If the effect goes
        # against the hypothesized direction, the one-sided p is 1 - p/2.
        if sign * h > 0:
            p_val = p_val / 2
        else:
            p_val = 1 - p_val / 2
        plotutil.xlabel(ax, 'p=%g, N=%d' % (p_val, len(vals_0)))
Exemplo n.º 9
0
# Earlier block sets, kept for reference:
#blocks = [2218, 2248, 2216]
#blocks = [2260, 2261, 2262, 2264, 2265, 2266, 2267]
blocks = [2274, 2275, 2276, 2277, 2281, 2282]
task_entry_set = dbfn.TaskEntrySet(blocks)

# Histogram bin edges: 0.5, 1.0, ..., 39.5 (the axes are labelled in cm/s).
bins = np.arange(0.5, 40, 0.5)

# One legend label per task entry: PPF entries are labelled with their 'tau'
# parameter and KF entries with 'KF'. Entries matching neither get no label.
# NOTE(review): in that case labels no longer line up with task_entries.
labels = []
for te in task_entry_set.task_entries:
    if 'PPF' in te.decoder_type:
        labels.append(str(te.params['tau']))
    elif 'KF' in te.decoder_type:
        labels.append('KF')

# Fresh white figure with two stacked panels.
plt.close('all')
plt.figure(facecolor='w')
axes = plotutil.subplots(2, 1, return_flat=True, hold=True)

# Top panel (intended speed) is currently disabled:
#task_entry_set.histogram(lambda te: te.intended_kin_norm(slice(3,6)), axes[0], bins)
# Bottom panel: cursor-speed histogram, one labelled trace per task entry.
task_entry_set.histogram(lambda entry: entry.cursor_speed(), axes[1], bins, labels=labels)
#task_entry_set.histogram(lambda te: te.cursor_speed('assist_off'), axes[1], bins, labels=labels)
plt.legend()

axes[0].set_xticks([])

# Axis labels, then limits.
plotutil.xlabel(axes[0], 'Est. of intended speed (cm/s)')
plotutil.xlabel(axes[1], 'Actual speed during CLDA (no assist) (cm/s)')
plotutil.set_xlim(axes[0], axis='y')
plotutil.set_xlim(axes[1], axis='y')
plotutil.ylabel(axes, 'Density')
plotutil.set_xlim(axes, [0,40])