def plot_model_results(self, model=None, fit=None):
    """Plot measured spike-response amplitudes alongside release-model output.

    Parameters
    ----------
    model, fit :
        Release model and fit result. When *model* is None, the pair stored
        by the last fit_release_model() call is used; raises if no fit has
        been run yet.

    Returns the PlotGrid holding the induction-frequency (top) and
    recovery-delay (bottom) panels.
    """
    if model is None:
        if self._last_model_fit is None:
            raise Exception("Must run fit_release_model before plotting results.")
        model, fit = self._last_model_fit
    spike_sets = self.spike_sets

    grid = PlotGrid()
    grid.set_shape(2, 1)
    ind_plot = grid[0, 0]
    rec_plot = grid[1, 0]
    ind_plot.setTitle('Release model fit - induction frequency')
    rec_plot.setTitle('Release model fit - recovery delay')
    # Both panels share labels and a log-x axis; x-axes are linked.
    for panel in (ind_plot, rec_plot):
        panel.setLabels(bottom=('time', 's'), left='relative amplitude')
        panel.setLogMode(x=True, y=False)
    ind_plot.setXLink(rec_plot)

    for idx, params in enumerate(self.stim_param_order):
        data_x, data_y = spike_sets[idx]
        fit_out = model.eval(data_x, fit.values(), dt=0.5)
        fit_x = fit_out[:, 0]
        fit_y = fit_out[:, 1]
        hue = (idx, 10)
        # Induction panel: stimuli whose recovery delay is ~250 ms.
        if params[1] - 0.250 < 5e-3:
            ind_plot.plot((data_x + 10) / 1000., data_y, pen=None, symbol='o', symbolBrush=hue)
            ind_plot.plot((fit_x + 10) / 1000., fit_y, pen=hue)
        # Recovery panel: stimuli delivered at 50 Hz induction frequency.
        if params[0] == 50:
            rec_plot.plot((data_x + 10) / 1000., data_y, pen=None, symbol='o', symbolBrush=hue)
            rec_plot.plot((fit_x + 10) / 1000., fit_y, pen=hue)

    grid.show()
    return grid
예제 #2
0
def summary_plot_pulse(grand_trace, feature_list, feature_mean, labels, titles, i, plot=None, color=None, name=None):
    """Scatter per-feature values at x-position *i* next to a grand-average trace.

    *feature_list* may be a tuple of feature arrays (one grid row each) or a
    single array. Column 0 holds the jittered feature scatter plus its mean;
    column 1 holds the grand-average trace. Creates and shows a new PlotGrid
    when *plot* is None; returns the grid either way.
    """
    # Tuple input means one row per feature; anything else is a single row.
    n_features = len(feature_list) if type(feature_list) is tuple else 1
    multi = n_features > 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for row in range(n_features):
            trace_panel = plot[row, 1]
            trace_panel.addLegend()
            trace_panel.setLabels(left=('Vm', 'V'))
            trace_panel.setLabels(bottom=('t', 's'))

    for row in range(n_features):
        if multi:
            values = feature_list[row]
            center = feature_mean[row]
            axis_label = labels[row]
            panel_title = titles[row]
        else:
            values = feature_list
            center = feature_mean
            axis_label = labels
            panel_title = titles
        scatter_panel = plot[row, 0]
        scatter_panel.setLabels(left=(axis_label[0], axis_label[1]))
        scatter_panel.hideAxis('bottom')
        scatter_panel.setTitle(panel_title)
        plot[row, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        # Jitter overlapping points horizontally so they remain visible.
        jitter = pg.pseudoScatter(np.array(values).astype(float), 0.3, bidir=True)
        scatter_panel.plot((0.3 * jitter / jitter.max()) + i, values, pen=None,
                           symbol='x', symbolSize=5, symbolBrush=color, symbolPen=None)
        scatter_panel.plot([i], [center], pen=None, symbol='o', symbolBrush=color,
                           symbolPen='w', symbolSize=10)
    return plot
    def plot_train_responses(self, plot_grid=None):
        """
        Plot individual and averaged train responses for each set of stimulus parameters.

        One grid row per stimulus-parameter set: column 0 shows the induction
        period, column 1 the recovery period. Individual baseline-subtracted
        traces are drawn in grey, the average in green.

        Return a new PlotGrid.
        """
        train_responses = self.train_responses

        if plot_grid is None:
            train_plots = PlotGrid()
        else:
            train_plots = plot_grid
        train_plots.set_shape(len(self.stim_param_order), 2)

        for i,stim_params in enumerate(train_responses.keys()):
            
            # Collect and plot average traces covering the induction and recovery 
            # periods for this set of stim params
            ind_group = train_responses[stim_params][0]
            rec_group = train_responses[stim_params][1]
            
            # Individual traces, baseline-subtracted using the median of the
            # induction-period baseline for both panels.
            for j in range(len(ind_group)):
                ind = ind_group.responses[j]
                rec = rec_group.responses[j]
                base = np.median(ind_group.baselines[j].data)
                train_plots[i,0].plot(ind.time_values, ind.data - base, pen=(128, 128, 128, 100))
                train_plots[i,1].plot(rec.time_values, rec.data - base, pen=(128, 128, 128, 100))
            ind_avg = ind_group.bsub_mean()
            rec_avg = rec_group.bsub_mean()

            # stim_params is assumed to be (induction_freq, recovery_delay, holding)
            # with delay/holding in seconds/volts -- TODO confirm against caller.
            ind_freq, rec_delay, holding = stim_params
            rec_delay = np.round(rec_delay, 2)
            train_plots[i,0].plot(ind_avg.time_values, ind_avg.data, pen='g', antialias=True)
            train_plots[i,1].plot(rec_avg.time_values, rec_avg.data, pen='g', antialias=True)
            train_plots[i,0].setLabels(left=('Vm', 'V'))
            label = pg.LabelItem("ind: %0.0f  rec: %0.0f  hold: %0.0f" % (ind_freq, rec_delay*1000, holding*1000))
            label.setParentItem(train_plots[i,0].vb)
            train_plots[i,0].label = label
            
        train_plots.show()
        # Link all panels to the top row so pan/zoom stays synchronized.
        train_plots.setYLink(train_plots[0,0])
        for i in range(train_plots.shape[0]):
            train_plots[i,0].setXLink(train_plots[0,0])
            train_plots[i,1].setXLink(train_plots[0,1])
        # Give the induction column more horizontal space than recovery.
        train_plots.grid.ci.layout.setColumnStretchFactor(0, 3)
        train_plots.grid.ci.layout.setColumnStretchFactor(1, 2)
        train_plots.setClipToView(False)  # has a bug :(
        train_plots.setDownsampling(True, True, 'peak')
        
        return train_plots
예제 #4
0
def trace_average_matrix(expts, **kwds):
    """Build a matrix of average traces for every pre/post cell-type pair.

    Parameters
    ----------
    expts :
        Experiment list passed through to plot_trace_average.
    **kwds :
        Extra keyword arguments forwarded to plot_trace_average.

    Returns a dict mapping (pre_type, post_type) -> average returned by
    plot_trace_average. Also creates and shows two PlotGrids (averages and
    individual traces) as a side effect.
    """
    types = ['tlx3', 'sim1', 'pvalb', 'sst', 'vip']
    results = {}
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # BUG FIX: use the `expts` argument; previously this referenced
            # the module-level `all_expts`, silently ignoring the parameter.
            avg = plot_trace_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            results[(pre_type, post_type)] = avg
    return results
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None, name=None):
    """Scatter per-feature values at x-position *i*, optionally with a trace panel.

    Parameters
    ----------
    feature_list : tuple of arrays | array
        Tuple -> one grid row per feature; otherwise a single row.
    labels, titles :
        Axis labels / panel titles, indexed per feature when tuple input.
    i : int
        X offset at which this group's points are drawn.
    median : bool
        Use nanmedian instead of nanmean as the center statistic.
    grand_trace :
        Optional trace object plotted in column 1.
    plot : PlotGrid | None
        Existing grid to draw into; a new one is created (and shown) if None.
    color, name :
        Pen color and legend name.

    Returns the PlotGrid.
    """
    if type(feature_list) is tuple:
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    # BUG FIX: choose the center statistic once so the `median` flag is
    # honored in both branches (previously the single-feature branch always
    # used np.nanmean, ignoring median=True).
    center_fn = np.nanmedian if median is True else np.nanmean

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        mean = center_fn(current_feature)
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        if len(current_feature) > 1:
            # Jitter overlapping points horizontally so they remain visible.
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            plot[feature, 0].plot([i], [mean], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            sem = stats.sem(current_feature, nan_policy='omit')
            plot[feature, 0].plot((0.3 * dx / dx.max()) + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                symbolBrush=color)

    return plot
예제 #6
0
def pulse_average_matrix(expts, **kwds):
    """Build a matrix of average pulse responses for every cell-type pair.

    Parameters
    ----------
    expts :
        Experiment list passed through to plot_pulse_average.
    **kwds :
        Extra keyword arguments forwarded to plot_pulse_average.

    Returns a dict mapping (pre_type, post_type) -> average returned by
    plot_pulse_average. Also creates and shows two PlotGrids (averages and
    individual responses) as a side effect.
    """
    results = {}
    types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip']
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # BUG FIX: use the `expts` argument; previously this referenced
            # the module-level `all_expts`, silently ignoring the parameter.
            avg = plot_pulse_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            # Keep the GUI responsive while filling the matrix.
            pg.QtGui.QApplication.processEvents()
            results[(pre_type, post_type)] = avg

    return results
    def plot_model_results(self, model=None, fit=None):
        """Plot spike amplitudes together with release-model predictions.

        Falls back to the model/fit stored by the last fit_release_model()
        call when *model* is None, raising if no fit exists yet. Returns a
        PlotGrid with the induction-frequency panel on top and the
        recovery-delay panel below.
        """
        if model is None:
            if self._last_model_fit is None:
                raise Exception(
                    "Must run fit_release_model before plotting results.")
            model, fit = self._last_model_fit
        spike_sets = self.spike_sets

        grid = PlotGrid()
        grid.set_shape(2, 1)
        induction, recovery = grid[0, 0], grid[1, 0]
        induction.setTitle('Release model fit - induction frequency')
        recovery.setTitle('Release model fit - recovery delay')
        for panel in (induction, recovery):
            panel.setLabels(bottom=('time', 's'), left='relative amplitude')
            panel.setLogMode(x=True, y=False)
        induction.setXLink(recovery)

        for idx, params in enumerate(self.stim_param_order):
            spike_x, spike_y = spike_sets[idx]
            model_out = model.eval(spike_x, fit.values(), dt=0.5)
            model_x = model_out[:, 0]
            model_y = model_out[:, 1]
            hue = (idx, 10)
            # Route this stimulus set to whichever panel(s) it belongs in:
            # ~250 ms recovery delay -> induction panel; 50 Hz -> recovery panel.
            targets = []
            if params[1] - 0.250 < 5e-3:
                targets.append(induction)
            if params[0] == 50:
                targets.append(recovery)
            for panel in targets:
                panel.plot((spike_x + 10) / 1000.,
                           spike_y,
                           pen=None,
                           symbol='o',
                           symbolBrush=hue)
                panel.plot((model_x + 10) / 1000., model_y, pen=hue)

        grid.show()
        return grid
def plot_response_averages(expt, show_baseline=False, **kwds):
    """Plot a grid of averaged evoked responses for every pre/post device pair.

    Parameters
    ----------
    expt :
        Experiment passed to MultiPatchExperimentAnalyzer.get().
    show_baseline : bool
        Currently unused -- the PSP-fit/baseline overlay step is disabled.
    **kwds :
        Forwarded to get_evoked_response_matrix(); 'clamp_mode' also selects
        the left-axis units ('ic' -> V, otherwise A).

    Returns the PlotGrid of averaged responses. Also opens a (currently
    empty) SNR-vs-NRMSE scatter window as a side effect.
    """
    analyzer = MultiPatchExperimentAnalyzer.get(expt)
    devs = analyzer.list_devs()

    # First get average evoked responses for all pre/post pairs
    responses, rows, cols = analyzer.get_evoked_response_matrix(**kwds)

    # resize plot grid accordingly
    plots = PlotGrid()
    plots.set_shape(len(rows), len(cols))
    plots.show()

    # ((y mins, y maxs), (x mins, x maxs)) collected across all panels
    ranges = [([], []), ([], [])]
    # NOTE: the PSP-fit step that populated `points` is disabled, so the
    # scatter window below is currently always empty.
    points = []

    # Plot each matrix element
    for i, dev1 in enumerate(rows):
        for j, dev2 in enumerate(cols):
            # select plot and hide axes
            plt = plots[i, j]
            if i < len(devs) - 1:
                plt.getAxis('bottom').setVisible(False)
            if j > 0:
                plt.getAxis('left').setVisible(False)

            # diagonal (self-pair) panels stay blank
            if dev1 == dev2:
                plt.getAxis('bottom').setVisible(False)
                plt.getAxis('left').setVisible(False)
                continue

            # adjust axes / labels
            plt.setXLink(plots[0, 0])
            plt.setYLink(plots[0, 0])
            plt.addLine(x=10e-3, pen=0.3)
            plt.addLine(y=0, pen=0.3)
            plt.setLabels(bottom=(str(dev2), 's'))
            if kwds.get('clamp_mode', 'ic') == 'ic':
                plt.setLabels(left=('%s' % dev1, 'V'))
            else:
                plt.setLabels(left=('%s' % dev1, 'A'))

            avg_response = responses[(dev1, dev2)].bsub_mean()
            if avg_response is not None:
                avg_response.t0 = 0
                t = avg_response.time_values
                # low-pass filter the average at 2 kHz before display
                y = bessel_filter(Trace(avg_response.data, dt=avg_response.dt), 2e3).data
                plt.plot(t, y, antialias=True)

                # keep track of data range across all plots
                ranges[0][0].append(y.min())
                ranges[0][1].append(y.max())
                ranges[1][0].append(t[0])
                ranges[1][1].append(t[-1])

    # BUG FIX: only set the shared range when at least one average response
    # was plotted; min()/max() on empty lists raised ValueError before.
    if ranges[0][0] and ranges[1][0]:
        plots[0, 0].setYRange(min(ranges[0][0]), max(ranges[0][1]))
        plots[0, 0].setXRange(min(ranges[1][0]), max(ranges[1][1]))

    # scatter plot of SNR vs NRMSE (empty while the fit step is disabled)
    plt = pg.plot()
    plt.setLabels(left='ln(SNR)', bottom='ln(NRMSE)')
    plt.plot([p['x'] for p in points], [p['y'] for p in points], pen=None, symbol='o', symbolBrush=[pg.mkBrush(p['brush']) for p in points])
    # show threshold line
    line = pg.InfiniteLine(pos=[0, 6], angle=180/np.pi * np.arctan(1))
    plt.addItem(line, ignoreBounds=True)

    return plots
예제 #9
0
def distance_plot(connected,
                  distance,
                  plots=None,
                  color=(100, 100, 255),
                  size=10,
                  window=40e-6,
                  name=None,
                  fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.
    
    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem
        (optional) Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The confidence interval
        will be drawn with reduced opacity (see *fill_alpha*)
    size: int
        size of scatter plot symbol
    window : float
        Width of distance window over which proportions are calculated for each point on
        the profile line.
    fill_alpha : int
        Opacity of confidence interval fill (0-255)

    Returns
    -------
    (plots, xvals, prop, upper, lower) : profile plots plus window-center
        x values, connection proportion, and CI bounds per window.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
        plots[0].setLabels(bottom=('distance', 'm'),
                           left='connection probability')

    if plots[1] is not None:
        # scatter points a bit: connected probes are drawn below zero,
        # unconnected above, jittered so overlapping points stay visible
        pts = np.vstack([distance, connected]).T
        conn = pts[:, 1] == 1
        unconn = pts[:, 1] == 0
        if np.any(conn):
            cscat = pg.pseudoScatter(pts[:, 0][conn],
                                     spacing=10e-6,
                                     bidir=False)
            mx = abs(cscat).max()
            if mx != 0:
                cscat = cscat * 0.2  # / mx
            pts[:, 1][conn] = -5e-5 - cscat
        if np.any(unconn):
            uscat = pg.pseudoScatter(pts[:, 0][unconn],
                                     spacing=10e-6,
                                     bidir=False)
            mx = abs(uscat).max()
            if mx != 0:
                uscat = uscat * 0.2  # / mx
            pts[:, 1][unconn] = uscat

        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        color2 = color + (100, )
        scatter = plots[1].plot(pts[:, 0],
                                pts[:, 1],
                                pen=None,
                                symbol='o',
                                labels={'bottom': ('distance', 'm')},
                                size=size,
                                symbolBrush=color2,
                                symbolPen=None,
                                name=name)
        # additive blending keeps dense overlapping points visible
        scatter.scatter.opts[
            'compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along with a 95% confidence interval
    # for connection probability
    bin_edges = np.arange(0, 500e-6, window)
    xvals, prop, lower, upper = connectivity_profile(connected, distance,
                                                     bin_edges)

    # plot connection probability and confidence intervals
    # (removed a dead `color2 = [c / 3.0 for c in color]` assignment that was
    # immediately overwritten below)
    xvals = (xvals[:-1] + xvals[1:]) * 0.5  # bin centers
    mid_curve = plots[0].plot(xvals,
                              prop,
                              pen={
                                  'color': color,
                                  'width': 3
                              },
                              antialias=True,
                              name=name)
    # invisible boundary curves exist only to anchor the CI fill
    upper_curve = plots[0].plot(xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha, )
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots, xvals, prop, upper, lower
예제 #10
0
# --- Analysis parameters -----------------------------------------------------
holding = [-55, -75]                     # holding potentials (mV) to analyze
freqs = [10, 20, 50, 100]                # induction frequencies (Hz)
rec_t = [250, 500, 1000, 2000, 4000]     # recovery delays (ms)
sweep_threshold = 3                      # minimum sweeps required per condition
deconv = True                            # also build deconvolution plot grids

# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)
# NOTE(review): `connection_types`, `data`, and `human_connections` are
# defined elsewhere in this script -- not visible in this chunk.
qc_plot = pg.plot()
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))
ind_plot.show()
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))
rec_plot.show()
if deconv is True:
    # parallel grids for exponentially deconvolved traces
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5, len(connection_types))
    deconv_rec_plot.show()
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']      # one scatter symbol per condition
trace_color = (0, 0, 0, 5)               # near-transparent black for overlays
# model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0}
# params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#           ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0],
#           ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0],
#           ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]}

# presumably data = (induction set, recovery set) -- verify against producer
ind = data[0]
rec = data[1]
freq = 10
delta =250
order = human_connections.keys()
delay_order = [250, 500, 1000, 2000, 4000]

ind_plt = PlotGrid()
ind_plt.set_shape(4, 1)
ind_plt.show()
rec_plt = PlotGrid()
rec_plt.set_shape(1, 2)
rec_plt.show()
ind_plt_scatter = pg.plot()
ind_plt_all = PlotGrid()
ind_plt_all.set_shape(1, 2)
ind_plt_all.show()
ind_50 = {}
ninth_pulse_250 = {}
gain_plot = pg.plot()


# NOTE: `type` shadows the builtin here (kept as-is to preserve behavior)
for t, type in enumerate(order):
    # skip layer-2/unknown -> layer-2/unknown pairs
    if type == (('2', 'unknown'), ('2', 'unknown')):
        continue
예제 #12
0
# NOTE(review): explicit dest='link-y-axis' keeps the hyphen, so the value is
# only reachable through vars(args)['link-y-axis'] (as done below), not via
# attribute access.
parser.add_argument('--link-y-axis',
                    action='store_true',
                    default=False,
                    dest='link-y-axis',
                    help='link all y-axis down a column')

args = vars(parser.parse_args(sys.argv[1:]))
# '--sweeps' and '--trains' are assumed to be registered earlier -- not
# visible in this chunk.
plot_sweeps = args['sweeps']
plot_trains = args['trains']
link_y_axis = args['link-y-axis']
# hard-coded local cache path -- machine-specific
expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl'
all_expts = ExperimentList(cache=expt_cache)

test = PlotGrid()
test.set_shape(len(connection_types.keys()), 1)
test.show()
grid = PlotGrid()
if plot_trains is True:
    # extra columns for the 50 Hz train and its exponential deconvolution
    grid.set_shape(len(connection_types.keys()), 3)
    grid[0, 1].setTitle(title='50 Hz Train')
    grid[0, 2].setTitle(title='Exponential Deconvolution')
    tau = 15e-3   # deconvolution time constant (s)
    lp = 1000     # low-pass cutoff (Hz)
else:
    grid.set_shape(len(connection_types.keys()), 2)
grid.show()
row = 0
grid[0, 0].setTitle(title='First Pulse')

# per-column y-range accumulators, filled later in the script
maxYpulse = []
maxYtrain = []
예제 #13
0
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), size=10, window=40e-6, spacing=None, name=None, fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.
    
    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem
        (optional) Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The confidence interval
        will be drawn with alpha=100
    size: int
        size of scatter plot symbol
    window : float
        Width of distance window over which proportions are calculated for each point on
        the profile line.
    spacing : float
        Distance spacing between points on the profile line
    fill_alpha : int
        Opacity of confidence interval fill (0-255)

    Note: using a spacing value that is smaller than the window size may cause an
    otherwise smooth decrease over distance to instead look more like a series of downward steps.

    Returns
    -------
    (plots, ci_xvals, prop, upper, lower) : profile plots plus the x values
        where a CI could be computed, proportion per window, and CI bounds.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)
    pts = np.vstack([distance, connected]).T
    
    # scatter points a bit: connected probes drawn below zero, unconnected
    # above, jittered so overlapping points remain visible
    conn = pts[:,1] == 1
    unconn = pts[:,1] == 0
    if np.any(conn):
        cscat = pg.pseudoScatter(pts[:,0][conn], spacing=10e-6, bidir=False)
        mx = abs(cscat).max()
        if mx != 0:
            cscat = cscat * 0.2# / mx
        pts[:,1][conn] = -5e-5 - cscat
    if np.any(unconn):
        uscat = pg.pseudoScatter(pts[:,0][unconn], spacing=10e-6, bidir=False)
        mx = abs(uscat).max()
        if mx != 0:
            uscat = uscat * 0.2# / mx
        pts[:,1][unconn] = uscat

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1,0], grid[0,0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
        plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    if plots[1] is not None:
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        color2 = color + (100,)
        scatter = plots[1].plot(pts[:,0], pts[:,1], pen=None, symbol='o', labels={'bottom': ('distance', 'm')}, size=size, symbolBrush=color2, symbolPen=None, name=name)
        # additive blending keeps dense overlapping points visible
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along with a 95% confidence interval
    # for connection probability

    if spacing is None:
        spacing = window / 4.0
        
    xvals = np.arange(window / 2.0, 500e-6, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    for x in xvals:
        minx = x - window / 2.0
        maxx = x + window / 2.0
        # select points inside this window
        mask = (distance >= minx) & (distance <= maxx)
        pts_in_window = connected[mask]
        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)
            ci = proportion_confint(n_conn, n_probed, method='beta')
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    # (removed a dead `color2 = [c / 3.0 for c in color]` assignment that was
    # immediately overwritten below)
    mid_curve = plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    # invisible boundary curves exist only to anchor the CI fill
    upper_curve = plots[0].plot(ci_xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(ci_xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha,)
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)
    
    return plots, ci_xvals, prop, upper, lower
예제 #14
0
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Shows raw pre/post traces for the selected sweeps, per-pulse averaged
    postsynaptic responses with PSP fits, a table of fit parameters, and
    controls for collecting event sets and fitting a release model to them.
    """
    def __init__(self, parent=None):
        # Sweeps / channels currently selected for analysis (see data_selected)
        self.sweeps = []
        self.channels = []

        # current_event_set: (pre_channel, post_channel, events, sweeps) for the
        # analysis currently displayed; event_sets: sets saved via the 'add' button
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        # Vertical stack: pre plot / post plot / per-pulse response grid / event area
        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)
        
        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')
        
        # Keep pre/post time axes locked together
        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)
        
        # One column per stimulus pulse (plus global average); shaped in _update_plots
        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)
        
        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)
        
        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)
        
        self.event_splitter.setSizes([600, 100])
        
        # List of event sets; row 0 is always the live, unsaved "current" set
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)
        
        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        # Grid used by fit_clicked to display release-model fits
        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        # Signal-conditioning stages applied to the postsynaptic trace
        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()
        
        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 10e-3, 'dec': True, 'minStep': 100e-6}
            
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps and channels to analyze, then refresh all plots.

        *channels* also populates the 'pre'/'post' channel selector lists.
        """
        self.sweeps = sweeps
        self.channels = channels
        
        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)
        
        self._update_plots()

    def _update_plots(self):
        """Rebuild every plot and the current event set from self.sweeps.

        Plots raw pre/post traces, detects evoked spikes on the presynaptic
        channel, averages the postsynaptic response following each pulse, fits
        a PSP to each average and to the global average, and fills the fit
        parameter table.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()
        
        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']
        
        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))
        
        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i,sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            
            # Detect pulse times from edges in the command waveform
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)
            
            # plot raw data
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)
                    
            # mark detected spike times on the presynaptic plot
            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None
        
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1) # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0,0])
        for i in range(1, npulses+1):
            self.response_plots[0,i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))
        
        fit_pen = {'color':(30, 30, 255), 'width':2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue
                
                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break
                    
                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 40ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len
                    
                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue
                
            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)
            
            # plot average response for this pulse
            # NOTE(review): this 'start' is computed but never used below — confirm intent
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0,i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            
            # let the user mess with this fit
            curve = self.response_plots[0,i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)
            
        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0,-1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0,-1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)
            
        # display fit parameters in table
        events = []
        for i,f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                # last entry is the global-average fit; no per-pulse spike stats
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k,f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)
            
        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP model to *data* sampled at *t* and return the fit result.

        The initial amplitude sign is inferred from whether the mean deviates
        above or below the baseline mode (first 1 ms); initial values and
        bounds are widened for current clamp ('ic') vs voltage clamp.
        """
        # baseline estimate from the first 1 ms of data
        mode = float_mode(data[:int(1e-3/dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # each tuple is (initial, min, max); scalars are plain initial values
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)
        
        # loose tolerance / low iteration cap keeps the UI responsive
        fit_kws = {'xtol': 1e-3, 'maxfev': 100}
        
        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set and add a row for it to the list."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces
        
    def remove_set_clicked(self):
        """Remove the selected saved event set (row 0, 'current', is kept)."""
        if self.event_set_list.currentRow() == 0:
            return
        sel = self.event_set_list.takeItem(self.event_set_list.currentRow())
        self.event_sets.remove(sel.event_set)
        
    def event_set_selected(self):
        """Show the selected event set's fit parameters in the table."""
        sel = self.event_set_list.selectedItems()[0]
        if sel.text() == "current":
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])
    
    def fit_clicked(self):
        """Fit the release model to all saved event sets and plot the results.

        Dynamics mechanisms (Dep, Fac, UR, SMR, DSR) are enabled one at a time
        (cumulatively), refitting the model after each addition.
        """
        from synaptic_release import ReleaseModel
        
        self.fit_plot.clear()
        self.fit_plot.show()
        
        n_sets = len(self.event_sets)
        self.fit_plot.set_shape(n_sets, 1)
        l = self.fit_plot[0, 0].legend
        if l is not None:
            # NOTE(review): in pyqtgraph, scene is a method on graphics items;
            # this likely should be l.scene().removeItem(l) — confirm
            l.scene.removeItem(l)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()
        
        for i in range(n_sets):
            self.fit_plot[i,0].setXLink(self.fit_plot[0, 0])
        
        spike_sets = []
        for i,evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            # normalize: time relative to first spike (ms), amplitude relative to first event
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))
            
            self.fit_plot[i,0].plot(x/1000., y, pen=None, symbol='o')
            
        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        # start with every mechanism disabled
        for k in dynamics_types:
            model.Dynamics[k] = 0
        
        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                # mechanisms are enabled cumulatively: Dep, then Dep+Fac, etc.
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return
        
        max_color = len(fit_params)*1.5

        for i,params in enumerate(fit_params):
            for j,spikes in enumerate(spike_sets):
                x, y = spikes
                # NOTE(review): 't' is computed but not used; model.eval is
                # evaluated at the measured spike times 'x' — confirm intent
                t = np.linspace(0, x.max(), 1000)
                output = model.eval(x, params.values())
                y = output[:,1]
                x = output[:,0]/1000.
                self.fit_plot[j,0].plot(x, y, pen=(i,max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or retarget) the FitExplorer for the clicked fit curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None, name=None):
    """
    Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for one
    group but ideal for comparing across many groups in the feature_list

    Parameters
    ----------
    feature_list : tuple of lists of floats, or a single list of floats
        single-pulse features such as amplitude. A tuple is treated as multiple
        features, one scatter row per feature; any other value is a single feature.
    labels : list of pyqtgraph.LabelItem
        axis labels, must be a list of same length as feature_list
    titles : list of strings
        plot title, must be a list of same length as feature_list
    i : integer
        iterator to place groups along x-axis
    median : boolean
        to calculate median (True) vs mean (False), default is False
    grand_trace : neuroanalysis.data.TraceView object
        option to plot response trace alongside scatter plot, default is None
    plot : PlotGrid
        If not None, plot the data on the referenced pyqtgraph object.
    color : tuple
        plot color
    name : string
        legend entry name for the grand-trace plot

    Returns
    -------
    plot : PlotGrid
        n x 2 grid with a scatter plot and optional trace response plot for each feature (n)

    """
    if isinstance(feature_list, tuple):
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        # BUG FIX: the median flag was previously ignored in the single-feature
        # branch (it always used nanmean); honor it in both cases.
        if median is True:
            mean = np.nanmedian(current_feature)
        else:
            mean = np.nanmean(current_feature)
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        if len(current_feature) > 1:
            # jitter points horizontally so overlapping values remain visible
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            plot[feature, 0].plot([i], [mean], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            if len(color) != 3:
                # normalize non-RGB color specifications to an (r, g, b) tuple
                new_color = pg.glColor(color)
                color = (new_color[0]*255, new_color[1]*255, new_color[2]*255)
            plot[feature, 0].plot((0.3 * dx / dx.max()) + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                symbolBrush=color)

    return plot
    def plot_train_responses(self, plot_grid=None):
        """
        Plot individual and averaged train responses for each set of stimulus parameters.

        Each grid row corresponds to one set of stimulus parameters; column 0
        shows the induction period, column 1 the recovery period.  Individual
        baseline-subtracted responses are drawn in grey with the group average
        overlaid in green.

        Return a new PlotGrid (or *plot_grid*, if one was supplied).
        """
        train_responses = self.train_responses

        if plot_grid is None:
            train_plots = PlotGrid()
        else:
            train_plots = plot_grid
        # NOTE(review): row count comes from self.stim_param_order while the loop
        # below iterates train_responses.keys(); assumes the two always agree — confirm
        train_plots.set_shape(len(self.stim_param_order), 2)

        for i, stim_params in enumerate(train_responses.keys()):

            # Collect and plot average traces covering the induction and recovery
            # periods for this set of stim params
            ind_group = train_responses[stim_params][0]
            rec_group = train_responses[stim_params][1]

            for j in range(len(ind_group)):
                ind = ind_group.responses[j]
                rec = rec_group.responses[j]
                # subtract the median of this response's baseline chunk
                base = np.median(ind_group.baselines[j].data)
                train_plots[i, 0].plot(ind.time_values,
                                       ind.data - base,
                                       pen=(128, 128, 128, 100))
                train_plots[i, 1].plot(rec.time_values,
                                       rec.data - base,
                                       pen=(128, 128, 128, 100))
            ind_avg = ind_group.bsub_mean()
            rec_avg = rec_group.bsub_mean()

            # stim_params is (induction frequency, recovery delay, holding potential)
            ind_freq, rec_delay, holding = stim_params
            rec_delay = np.round(rec_delay, 2)
            train_plots[i, 0].plot(ind_avg.time_values,
                                   ind_avg.data,
                                   pen='g',
                                   antialias=True)
            train_plots[i, 1].plot(rec_avg.time_values,
                                   rec_avg.data,
                                   pen='g',
                                   antialias=True)
            train_plots[i, 0].setLabels(left=('Vm', 'V'))
            label = pg.LabelItem("ind: %0.0f  rec: %0.0f  hold: %0.0f" %
                                 (ind_freq, rec_delay * 1000, holding * 1000))
            label.setParentItem(train_plots[i, 0].vb)
            train_plots[i, 0].label = label

        train_plots.show()
        # link axes: all rows share Y; each column shares its own X
        train_plots.setYLink(train_plots[0, 0])
        for i in range(train_plots.shape[0]):
            train_plots[i, 0].setXLink(train_plots[0, 0])
            train_plots[i, 1].setXLink(train_plots[0, 1])
        train_plots.grid.ci.layout.setColumnStretchFactor(0, 3)
        train_plots.grid.ci.layout.setColumnStretchFactor(1, 2)
        train_plots.setClipToView(False)  # has a bug :(
        train_plots.setDownsampling(True, True, 'peak')

        return train_plots
# Example #17
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO 
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.
    Create a test set of data for testing the fit_psp function.  Uses Steph's 
    original first_puls_feature.py code to filter out error causing data.
    
    Example run statement
    python save save_fit_psp_test_set.py --organism mouse --connection ee
    
    Comment in the code that does the saving at the bottom
    """

    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description=
        'Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
        'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism',
                        dest='organism',
                        help='Select mouse or human')
    parser.add_argument('--connection',
                        dest='connection',
                        help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {
        'Type': [],
        'Connection': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'rise2080': [],
        'rise1090': [],
        'rise1080': [],
        'decay': [],
        'nrmse': [],
        'CV': []
    }
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            connection_types == ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                post_type = (None, c_type[0])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1],
                                                          'unknown'))]

    plt = pg.plot()

    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type,
                                     target_layer=target_layer,
                                     calcium=calcium,
                                     age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {
            'trace': [],
            'amp': [],
            'latency': [],
            'rise': [],
            'dist': [],
            'decay': [],
            'CV': [],
            'amp_measured': []
        }
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[
                    0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[
                    0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(
                        expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(
                        pulse_response,
                        freq_range=[0, 50],
                        holding_range=holding,
                        pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset,
                                           baseline=1.5,
                                           pulse=None,
                                           plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(
                                qc_list)
                            #                        if amp_sign is '-':
                            #                            continue
                            #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                            #                        all_amps = fail_rate(response_subset, '+', peak_t)
                            #                        cv = np.std(all_amps)/np.mean(all_amps)
                            #
                            #                        # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(
                                len(avg_trace.data
                                    )) * 10.  #set everything to ten initially
                            weight[int(10e-3 / dt):int(
                                12e-3 / dt)] = 0.  #area around stim artifact
                            weight[int(12e-3 / dt):int(
                                19e-3 / dt)] = 30.  #area around steep PSP rise

                            # check if the test data dir is there and if not create it
                            test_data_dir = 'test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)

                            save_dict = {}
                            save_dict['input'] = {
                                'data': avg_trace.data.tolist(),
                                'dtype': str(avg_trace.data.dtype),
                                'dt': float(avg_trace.dt),
                                'amp_sign': amp_sign,
                                'yoffset': 0,
                                'xoffset': 14e-3,
                                'avg_amp': float(avg_amp),
                                'method': 'leastsq',
                                'stacked': False,
                                'rise_time_mult_factor': 10.,
                                'weight': weight.tolist()
                            }

                            # need to remake trace because different output is created
                            avg_trace_simple = Trace(
                                data=np.array(save_dict['input']['data']),
                                dt=save_dict['input']
                                ['dt'])  # create Trace object

                            psp_fits_original = fit_psp(
                                avg_trace,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']
                                ['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })

                            psp_fits_simple = fit_psp(
                                avg_trace_simple,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']
                                ['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })
                            print expt.uid, pre, post
                            if psp_fits_original.nrmse(
                            ) != psp_fits_simple.nrmse():
                                print '  the nrmse values dont match'
                                print '\toriginal', psp_fits_original.nrmse()
                                print '\tsimple', psp_fits_simple.nrmse()
def summary_plot_pulse(feature_list,
                       labels,
                       titles,
                       i,
                       median=False,
                       grand_trace=None,
                       plot=None,
                       color=None,
                       name=None):
    """
    Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for one
    group by ideal for comparing across many groups in the feature_list

    Parameters
    ----------
    feature_list : tuple of lists of floats
        single-pulse features such as amplitude. Can be multiple features each a list themselves
    labels : list of pyqtgraph.LabelItem
        axis labels, must be a list of same length as feature_list
    titles : list of strings
        plot title, must be a list of same length as feature_list
    i : integer
        iterator to place groups along x-axis
    median : boolean
        to calculate median (True) vs mean (False), default is False
    grand_trace : neuroanalysis.data.TSeriesView object
        option to plot response trace alongside scatter plot, default is None
    plot : pyqtgraph.PlotItem
        If not None, plot the data on the referenced pyqtgraph object.
    color : tuple
        plot color
    name : string
        legend name for the grand-average trace curve

    Returns
    -------
    plot : pyqtgraph.PlotItem
        2 x n plot with scatter plot and optional trace response plot for each feature (n)

    """
    # Multiple features arrive as a tuple of lists; a single feature is a
    # bare list.
    if isinstance(feature_list, tuple):
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            # nan-aware center so missing samples don't poison the summary
            if median:
                mean = np.nanmedian(current_feature)
            else:
                mean = np.nanmean(current_feature)
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            # Fixed: previously this branch always used nanmean, silently
            # ignoring the `median` flag for a single feature.
            if median:
                mean = np.nanmedian(current_feature)
            else:
                mean = np.nanmean(current_feature)
            label = labels
            title = titles
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values,
                                  grand_trace.data,
                                  pen=color,
                                  name=name)
        if len(current_feature) > 1:
            # spread points horizontally (pseudo-scatter) so that equal or
            # near-equal values remain individually visible
            dx = pg.pseudoScatter(np.array(current_feature).astype(float),
                                  0.7,
                                  bidir=True)
            plot[feature, 0].plot([i], [mean],
                                  symbol='o',
                                  symbolSize=20,
                                  symbolPen='k',
                                  symbolBrush=color)
            # pg.glColor normalizes palette/int colors to 0-1 floats; rescale
            # to 0-255 RGB so an alpha channel can be appended below.
            # (Removed a dead `stats.sem(...)` computation whose result was
            # never used.)
            if len(color) != 3:
                new_color = pg.glColor(color)
                color = (new_color[0] * 255, new_color[1] * 255,
                         new_color[2] * 255)
            plot[feature, 0].plot(
                (0.3 * dx / dx.max()) + i,
                current_feature,
                pen=None,
                symbol='o',
                symbolSize=10,
                symbolPen='w',
                symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i],
                                  current_feature,
                                  pen=None,
                                  symbol='o',
                                  symbolSize=10,
                                  symbolPen='w',
                                  symbolBrush=color)

    return plot
def summary_plot_pulse(feature_list,
                       feature_mean,
                       labels,
                       titles,
                       i,
                       grand_trace=None,
                       plot=None,
                       color=None,
                       name=None):
    """Scatter/bar summary of single-pulse features, one grid row per feature.

    Column 0 of each row holds the per-group scatter (mean bar plus SEM error
    bar when more than one sample is present); column 1 holds the optional
    grand-average response trace.  Returns the PlotGrid so successive groups
    can be overlaid by passing it back in via `plot`.
    """
    is_multi = type(feature_list) is tuple
    n_features = len(feature_list) if is_multi else 1
    if plot is None:
        # Fresh grid: one row per feature, legend + axis labels on the
        # trace column.
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for row in range(n_features):
            plot[row, 1].addLegend()
            plot[row, 1].setLabels(left=('Vm', 'V'))
            plot[row, 1].setLabels(bottom=('t', 's'))

    for row in range(n_features):
        if n_features > 1:
            vals = feature_list[row]
            avg = feature_mean[row]
            label = labels[row]
            title = titles[row]
        else:
            vals, avg, label, title = feature_list, feature_mean, labels, titles
        scatter_plot = plot[row, 0]
        scatter_plot.setLabels(left=(label[0], label[1]))
        scatter_plot.hideAxis('bottom')
        scatter_plot.setTitle(title)
        if grand_trace is not None:
            plot[row, 1].plot(grand_trace.time_values,
                              grand_trace.data,
                              pen=color,
                              name=name)
        if len(vals) > 1:
            # Pseudo-scatter offsets keep coincident values visible.
            dx = pg.pseudoScatter(np.array(vals).astype(float),
                                  0.3,
                                  bidir=True)
            bar = pg.BarGraphItem(x=[i],
                                  height=avg,
                                  width=0.7,
                                  brush='w',
                                  pen={
                                      'color': color,
                                      'width': 2
                                  })
            scatter_plot.addItem(bar)
            sem = stats.sem(vals)
            err = pg.ErrorBarItem(x=np.asarray([i]),
                                  y=np.asarray([avg]),
                                  height=sem,
                                  beam=0.3)
            scatter_plot.addItem(err)
            scatter_plot.plot((0.3 * dx / dx.max()) + i,
                              vals,
                              pen=None,
                              symbol='o',
                              symbolSize=10,
                              symbolPen='w',
                              symbolBrush=color)
        else:
            scatter_plot.plot([i],
                              vals,
                              pen=None,
                              symbol='o',
                              symbolSize=10,
                              symbolPen='w',
                              symbolBrush=color)

    return plot
# Unpack DB query rows into arrays: baseline RMS value, rig ID, headstage ID,
# timestamp.  (`rows` is produced earlier in the file -- not visible here.)
rms = np.array([row[0] for row in rows])
rig = np.array([row[1] for row in rows]).astype(int)
hs = np.array([row[2] for row in rows]).astype(int)
# Collapse (rig, headstage) to a single x-axis column; 8 headstage slots per rig.
col = rig*8 + hs

# Sweep timestamps as seconds since epoch, for the per-rig time-series plots.
ts = np.array([time.mktime(row[3].timetuple()) for row in rows])
#ts -= ts[0]

# Scatter of RMS per headstage column; x is jittered to reduce overplotting.
pg.plot(col + np.random.uniform(size=len(col))*0.7, rms, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 50))

# Histogram of baseline RMS error, one step curve per rig (normalized per rig).
plt = pg.plot(labels={'left': 'number of sweeps (normalized per rig)', 'bottom': ('baseline rms error', 'V')})
plt.addLegend()

# One grid row per rig: RMS vs. time.
grid = PlotGrid()
grid.set_shape(3, 1)
grid.show()

for r, c in ((1, 'r'), (2, 'g'), (3, 'b')):
    mask = rig==r
    rig_data = rms[mask]
    # Fixed bin edges 0..2 mV; normalize counts by the rig's sweep count.
    y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000))
    plt.plot(x, y/len(rig_data), stepMode=True, connect='finite', pen=c, name="Rig %d" % r)

    p = grid[r-1, 0]
    p.plot(ts[mask], rig_data, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 100))
    p.setLabels(left=('rig %d baseline rms noise'%r, 'V'))

# Link all rows to the first so pan/zoom stays synchronized across rigs.
grid.setXLink(grid[0, 0])
grid.setYLink(grid[0, 0])

class DistancePlot(object):
    """Interactive connection-probability vs. distance figure.

    Wraps a two-row PlotGrid (scatter of probed pairs above, probability
    profile below) rendered by ``distance_plot``, and re-renders everything
    whenever the user changes the distance-binning-window parameter.
    """
    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # Bottom row (the probability profile) gets twice the stretch of
        # the scatter row.
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # (profile plot, scatter plot) -- the order distance_plot expects.
        self.plots = (self.grid[1,0], self.grid[0,0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        # User-adjustable sliding-window width used to bin probes by distance.
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        # Overlaid elements and their colors, replayed by update_plot.
        self.elements = []
        self.element_colors = []
        # First dataset plotted; cached so update_plot can redraw it.
        self.results = None
        self.color = None
        self.name = None

        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10):
        """Results needs to be a DataFrame or Series object with 'connected' and 'distance' as columns

        """
        # Remember the first dataset so update_plot can redraw it when the
        # binning window changes.
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        connected = results['connected']
        distance = results['distance'] 
        dist_win = self.params.value()
        # window == spacing: exactly one profile point per binning window.
        self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)
        self.plots[0].setXRange(0, 200e-6)
        # 
        return self.dist_plot

    def invalidate_output(self):
        # Drop all plot items; curves are rebuilt on the next plot call.
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance profile for one (pre, post) element, labeled 'pre->post'."""
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name):
        """Clear all overlays and redraw only the base results."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10)

    def update_plot(self):
        """Re-render base results plus all overlaid elements (parameter-change slot)."""
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
예제 #22
0
        rms,
        pen=None,
        symbol='o',
        symbolPen=None,
        symbolBrush=(255, 255, 255, 50))

plt = pg.plot(
    labels={
        'left': 'number of sweeps (normalized per rig)',
        'bottom': ('baseline rms error', 'V')
    })
plt.addLegend()

grid = PlotGrid()
grid.set_shape(3, 1)
grid.show()

for r, c in ((1, 'r'), (2, 'g'), (3, 'b')):
    mask = rig == r
    rig_data = rms[mask]
    y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000))
    plt.plot(x,
             y / len(rig_data),
             stepMode=True,
             connect='finite',
             pen=c,
             name="Rig %d" % r)

    p = grid[r - 1, 0]
    p.plot(ts[mask],
           rig_data,
예제 #23
0
def distance_plot(connected,
                  distance,
                  plots=None,
                  color=(100, 100, 255),
                  window=40e-6,
                  spacing=None,
                  name=None,
                  fill_alpha=30,
                  size=10):
    """Draw connectivity vs distance profiles with confidence intervals.
    
    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem
        (optional) Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The confidence interval
        will be drawn with alpha=100
    window : float
        Width of distance window over which proportions are calculated for each point on
        the profile line.
    spacing : float
        Distance spacing between points on the profile line
    name : str | None
        Legend name for the profile curve and scatter points.
    fill_alpha : int
        Alpha (0-255) used for the confidence-interval fill.
    size : int
        Symbol size for the scatter plot of probed pairs.  Default 10 matches
        the previous (implicit) pyqtgraph default, and callers such as
        DistancePlot.plot_distance pass size explicitly.

    Note: using a spacing value that is smaller than the window size may cause an
    otherwise smooth decrease over distance to instead look more like a series of downward steps.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)
    pts = np.vstack([distance, connected]).T

    # Jitter the scatter points vertically so overlapping probes remain
    # visible: connected pairs hang below y=0, unconnected pairs sit above.
    conn = pts[:, 1] == 1
    unconn = pts[:, 1] == 0
    if np.any(conn):
        cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
        mx = abs(cscat).max()
        if mx != 0:
            cscat = cscat * 0.2  # / mx
        pts[:, 1][conn] = -2e-5 - cscat
    if np.any(unconn):
        uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
        mx = abs(uscat).max()
        if mx != 0:
            uscat = uscat * 0.2  # / mx
        pts[:, 1][unconn] = uscat

    # scatter plot connections probed; build a default two-row layout
    # (scatter above, profile below) if the caller did not supply plots
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    # Scatter may be suppressed by passing plots=(profile, None).
    if plots[1] is not None:
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        color2 = color + (100, )
        scatter = plots[1].plot(pts[:, 0],
                                pts[:, 1],
                                pen=None,
                                symbol='o',
                                symbolSize=size,
                                labels={'bottom': ('distance', 'm')},
                                symbolBrush=color2,
                                symbolPen=None,
                                name=name)
        scatter.scatter.opts[
            'compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along with a 95% confidence interval
    # for connection probability

    if spacing is None:
        spacing = window / 4.0

    xvals = np.arange(window / 2.0, 500e-6, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    for x in xvals:
        minx = x - window / 2.0
        maxx = x + window / 2.0
        # select points inside this window
        mask = (distance >= minx) & (distance <= maxx)
        pts_in_window = connected[mask]
        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            # keep `prop` aligned with `xvals`; the CI lists only hold
            # windows that actually contained probes
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)
            ci = binomial_ci(n_conn, n_probed)
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    # (removed a dead `color2 = [c / 3.0 for c in color]` assignment that was
    # always overwritten before use)
    mid_curve = plots[0].plot(xvals,
                              prop,
                              pen={
                                  'color': color,
                                  'width': 3
                              },
                              antialias=True,
                              name=name)
    # CI bound curves are kept invisible; they exist only to anchor the fill.
    upper_curve = plots[0].plot(ci_xvals,
                                upper,
                                pen=(0, 0, 0, 0),
                                antialias=True)
    lower_curve = plots[0].plot(ci_xvals,
                                lower,
                                pen=(0, 0, 0, 0),
                                antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha, )
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots
예제 #24
0
    pg.dbg()

    all_expts = cached_experiments()
    
    # results = pulse_average_matrix(all_expts, clamp_mode='ic', min_duration=25e-3, pulse_ids=[0, 8])

    # 1. collect a list of all experiments containing connections
    conn_expts = {}
    for c in all_expts.connection_summary():
        conn_expts.setdefault(c['expt'], []).append(c['cells'])

    # 2. for each experiment, get and cache the full set of first-pulse averages
    types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip']
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    indplots.show()

    cachefile = "first_pulse_average.pkl"
    cache = pickle.load(open(cachefile, 'rb')) if os.path.isfile(cachefile) else {}

    for expt, conns in sorted(conn_expts.items()):
        if expt.source_id not in cache:
            print "Load:", expt.source_id
            try:
                with expt.data:
                    analyzer = MultiPatchExperimentAnalyzer.get(expt.data)
                    responses = []
                    for pre_cell, post_cell in conns:
                        pre_id, post_id = pre_cell.cell_id-1, post_cell.cell_id-1
                        resp = analyzer.get_evoked_responses(pre_id, post_id, 
                            clamp_mode='ic', 
예제 #25
0
    else:
        c_type = connection.split('-')
        connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]

plt = pg.plot()

# NOTE(review): unused in this visible fragment; presumably scale-bar
# placement for trace plots -- confirm against the rest of the script.
scale_offset = (-20, -20)
scale_anchor = (0.4, 1)
# Accepted holding-potential range for sweep filtering (units per
# response_filter's holding_range; presumably mV -- TODO confirm).
holding = [-65, -75]
qc_plot = pg.plot()
# Accumulators keyed by connection type.
grand_response = {}
expt_ids = {}
feature_plot = None
feature2_plot = PlotGrid()
feature2_plot.set_shape(5,1)
feature2_plot.show()
feature3_plot = PlotGrid()
feature3_plot.set_shape(1, 3)
feature3_plot.show()
amp_plot = pg.plot()
# One grid row per connection type for the grand-average synapse traces.
synapse_plot = PlotGrid()
synapse_plot.set_shape(len(connection_types), 1)
synapse_plot.show()
for c in range(len(connection_types)):
    cre_type = (connection_types[c][0][1], connection_types[c][1][1])
    target_layer = (connection_types[c][0][0], connection_types[c][1][0])
    type = connection_types[c]
    expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
    color = color_palette[c]
    grand_response[type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
    expt_ids[type[0]] = []
예제 #26
0
# model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0}
# params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#           ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0],
#           ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0],
#           ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]}

# NOTE(review): presumably data[0] = induction data, data[1] = recovery data
# (matching the ind_/rec_ plot names below) -- confirm against the producer
# of `data`, which is not visible in this fragment.
ind = data[0]
rec = data[1]
# Stimulus parameters selected for analysis; presumably 10 Hz induction and
# a 250 ms recovery delay -- TODO confirm units.
freq = 10
delta = 250
order = human_connections.keys()
delay_order = [250, 500, 1000, 2000, 4000]

# Plot layout: 4-row grid for induction traces, 1x2 grids for recovery and
# combined-induction views, plus standalone scatter/gain plots.
ind_plt = PlotGrid()
ind_plt.set_shape(4, 1)
ind_plt.show()
rec_plt = PlotGrid()
rec_plt.set_shape(1, 2)
rec_plt.show()
ind_plt_scatter = pg.plot()
ind_plt_all = PlotGrid()
ind_plt_all.set_shape(1, 2)
ind_plt_all.show()
# Accumulators: 50 Hz induction responses and 9th-pulse / 250 ms amplitudes.
ind_50 = {}
ninth_pulse_250 = {}
gain_plot = pg.plot()
for t, type in enumerate(order):
    if type == (('2', 'unknown'), ('2', 'unknown')):
        continue
    if type == (('4', 'unknown'), ('4', 'unknown')):
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO 
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.
    Create a test set of data for testing the fit_psp function.  Uses Steph's 
    original first_puls_feature.py code to filter out error causing data.
    
    Example run statement
    python save save_fit_psp_test_set.py --organism mouse --connection ee
    
    Comment in the code that does the saving at the bottom
    """
    
    
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os
    
    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')
    
    parser = argparse.ArgumentParser(description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))
    
    all_expts = cached_experiments()
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [],'rise':[], 'rise2080': [], 'rise1090': [], 'rise1080': [],
                'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}
    
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            connection_types == ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                post_type = (None, c_type[0])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    
    plt = pg.plot()
    
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5,1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
    #                        if amp_sign is '-':
    #                            continue
    #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
    #                        all_amps = fail_rate(response_subset, '+', peak_t)
    #                        cv = np.std(all_amps)/np.mean(all_amps)
    #                        
    #                        # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data))*10.  #set everything to ten initially
                            weight[int(10e-3/dt):int(12e-3/dt)] = 0.   #area around stim artifact
                            weight[int(12e-3/dt):int(19e-3/dt)] = 30.  #area around steep PSP rise 
                            
                            # check if the test data dir is there and if not create it
                            test_data_dir='test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)
                                
                            save_dict={}
                            save_dict['input']={'data': avg_trace.data.tolist(),
                                                'dtype': str(avg_trace.data.dtype),
                                                'dt': float(avg_trace.dt),
                                                'amp_sign': amp_sign,
                                                'yoffset': 0, 
                                                'xoffset': 14e-3, 
                                                'avg_amp': float(avg_amp),
                                                'method': 'leastsq', 
                                                'stacked': False, 
                                                'rise_time_mult_factor': 10., 
                                                'weight': weight.tolist()} 
                            
                            # need to remake trace because different output is created
                            avg_trace_simple=Trace(data=np.array(save_dict['input']['data']), dt=save_dict['input']['dt']) # create Trace object
                            
                            psp_fits_original = fit_psp(avg_trace, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
    
                            psp_fits_simple = fit_psp(avg_trace_simple, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
                            print expt.uid, pre, post    
                            if psp_fits_original.nrmse()!=psp_fits_simple.nrmse():     
                                print '  the nrmse values dont match'
                                print '\toriginal', psp_fits_original.nrmse()
                                print '\tsimple', psp_fits_simple.nrmse()
class DistancePlot(object):
    """Two-row plot of connection probability vs. intersomatic distance.

    The bottom row (self.plots[0]) shows the binned connection-probability
    curve; the top row (self.plots[1]) shows the raw connected/unconnected
    scatter.  The binning window is user-adjustable through *self.params*;
    changing it re-plots the cached result set.
    """
    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # give the probability plot (bottom row) more vertical space
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # plots[0] = probability curve (row 1), plots[1] = scatter (row 0)
        self.plots = (self.grid[1,0], self.grid[0,0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.plots[0].setXRange(0, 200e-6)
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        # cached arguments from the first plot_distance() call; re-used by update_plot()
        self.results = None
        self.color = None
        self.name = None

        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10, suppress_scatter=False):
        """Plot connection probability vs distance.

        *results* must be a DataFrame or Series object with 'Connected' and
        'Distance' columns.  When *suppress_scatter* is True only the
        probability curve is drawn (the per-pair scatter is very slow to
        render for large result sets).
        """
        connected = results[~results['Distance'].isnull()]['Connected']
        distance = results[~results['Distance'].isnull()]['Distance']
        dist_win = self.params.value()
        # remember the first result set so update_plot() can re-plot it when
        # the binning window changes
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        if suppress_scatter is True:
            # suppress scatter plot for all results (takes forever to plot)
            plots = list(self.plots)
            plots[1] = None
            self.dist_plot = distance_plot(connected, distance, plots=plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)
        else:
            self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)

        return self.dist_plot

    def invalidate_output(self):
        """Clear all plot contents."""
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance plot for one matrix element (pre->post class pair)."""
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name, suppress_scatter=False):
        """Drop any element overlays and re-plot the base result set."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10, suppress_scatter=suppress_scatter)

    def update_plot(self):
        """Re-plot in response to a binning-window parameter change.

        Bug fix: this slot is connected in __init__, so the user may change
        the binning parameter before anything has been plotted.  In that
        case self.results is still None and re-plotting must be skipped.
        """
        if self.results is None:
            return
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name, suppress_scatter=True)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
    grid1.set_shape(2, 1)
    for i in range(n_responses):
        r = responses.responses[i]
        grid1[0, 0].plot(r.time_values, r.data)
        
        filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.)
        responses.responses[i] = filt
        
        dec = exp_deconvolve(r, 15e-3)
        baseline = np.median(dec.data[:100])
        r2 = bessel_filter(dec-baseline, 300.)
        grid1[1, 0].plot(r2.time_values, r2.data)
        
        deconv.append(r2)
    
    grid1.show()
    

    def measure_amp(trace, baseline=(6e-3, 8e-3), response=(13e-3, 17e-3)):
        """Return the peak response amplitude relative to the mean baseline level."""
        base_level = trace.time_slice(baseline[0], baseline[1]).data.mean()
        peak_level = trace.time_slice(response[0], response[1]).data.max()
        return peak_level - base_level

    
    # Chop up responses into groups of varying size and plot the average
    # amplitude as measured from these chunks. We expect to see variance
    # decrease as a function of the number of responses being averaged together.
    x = []
    y_deconv1 = []
    y_deconv2 = []
    y_deconv3 = []
예제 #30
0
class TSeriesPlot(pg.GraphicsLayoutWidget):
    """Four-row plot grid showing presynaptic spikes and postsynaptic
    responses at two holding potentials, with overlaid PSP fit curves.

    Rows top to bottom: spikes then traces for the first holding level,
    spikes then traces for the second.  All x-axes are linked and aligned
    to the presynaptic spike time (t=0).
    """
    def __init__(self, title, units):
        pg.GraphicsLayoutWidget.__init__(self)
        # NOTE(review): the plots live in a separate PlotGrid window rather
        # than inside this GraphicsLayoutWidget -- confirm that is intended.
        self.grid = PlotGrid()
        self.grid.set_shape(4, 1)
        # relative row heights: spike rows shorter than trace rows
        self.grid.grid.ci.layout.setRowStretchFactor(0, 3)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 8)
        self.grid.grid.ci.layout.setRowStretchFactor(2, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(3, 10)
        self.grid.show()
        # rows 1 and 3 hold postsynaptic traces; rows 0 and 2 the spikes
        self.trace_plots = (self.grid[1, 0], self.grid[3, 0])
        self.spike_plots = (self.grid[0, 0], self.grid[2, 0])
        self.plots = self.spike_plots + self.trace_plots
        for plot in self.plots[:-1]:
            plot.hideAxis('bottom')
        self.plots[-1].setLabel('bottom', text='Time from spike', units='s')
        # current fit curves for the -55 / -70 holding plots
        self.fit_item_55 = None
        self.fit_item_70 = None
        self.fit_color = {True: 'g', False: 'r'}  # green = fit pass, red = fail
        self.qc_color = {'qc_pass': (255, 255, 255, 100), 'qc_fail': (255, 0, 0, 100)}

        self.plots[0].setTitle(title)
        # NOTE(review): 'holdings' is a module-level name defined elsewhere;
        # assumed to be a 2-element sequence of holding potentials -- confirm.
        for (plot, holding) in zip(self.trace_plots, holdings):
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="%d holding" % int(holding), units=units)
        for plot in self.spike_plots:
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="presynaptic spike")
            plot.addLine(x=0)  # mark the spike time
        self.plots[-1].setXRange(-5e-3, 10e-3)
        
        # plot items added by plot_traces()/plot_spikes(); removed by clear_plots()
        self.items = []

    def plot_responses(self, pulse_responses):
        """Plot both postsynaptic traces and presynaptic spikes.

        *pulse_responses* maps holding level -> {'qc_pass'|'qc_fail': [pulse responses]}.
        """
        self.plot_traces(pulse_responses)
        self.plot_spikes(pulse_responses)

    def plot_traces(self, pulse_responses):  
        """Plot spike-aligned, baseline-subtracted postsynaptic traces per
        holding level, plus a blue grand average of the qc-passing traces.
        """
        # NOTE(review): indexes trace_plots by enumeration order of the dict
        # keys -- assumes at most 2 holding levels; confirm.
        for i, holding in enumerate(pulse_responses.keys()):
            for qc, prs in pulse_responses[holding].items():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                post_ts = prl.post_tseries(align='spike', bsub=True)
                
                for trace in post_ts:
                    item = self.trace_plots[i].plot(trace.time_values, trace.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        # draw failing traces behind passing ones
                        item.setZValue(-10)
                    self.items.append(item)
                if qc == 'qc_pass':
                    grand_trace = post_ts.mean()
                    item = self.trace_plots[i].plot(grand_trace.time_values, grand_trace.data, pen={'color': 'b', 'width': 2})
                    self.items.append(item)
            self.trace_plots[i].autoRange()
            self.trace_plots[i].setXRange(-5e-3, 10e-3)
            # y_range = [grand_trace.data.min(), grand_trace.data.max()]
            # self.plots[i].setYRange(y_range[0], y_range[1], padding=1)

    def plot_spikes(self, pulse_responses):
        """Plot the spike-aligned presynaptic spikes; responses whose stim
        pulse evoked more than one spike are colored as qc failures.
        """
        for i, holding in enumerate(pulse_responses.keys()):
            for prs in pulse_responses[holding].values():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                pre_ts = prl.pre_tseries(align='spike', bsub=True)
                for pr, spike in zip(prl, pre_ts):
                    qc = 'qc_pass' if pr.stim_pulse.n_spikes == 1 else 'qc_fail'
                    item = self.spike_plots[i].plot(spike.time_values, spike.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        item.setZValue(-10)
                    self.items.append(item)

    def plot_fit(self, trace, holding, fit_pass=False):
        """Overlay (or replace) the PSP fit curve on the -55 or -70 trace plot."""
        if holding == -55:
            if self.fit_item_55 is not None:
                self.trace_plots[0].removeItem(self.fit_item_55)
            self.fit_item_55 = pg.PlotDataItem(trace.time_values, trace.data, name='-55 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[0].addItem(self.fit_item_55)
        
        elif holding == -70:
            if self.fit_item_70 is not None:
                self.trace_plots[1].removeItem(self.fit_item_70)
            self.fit_item_70 = pg.PlotDataItem(trace.time_values, trace.data, name='-70 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[1].addItem(self.fit_item_70)

    def color_fit(self, name, value):
        """Recolor the matching fit curve for pass (True/green) or fail (False/red)."""
        if '-55' in name:
            if self.fit_item_55 is not None:
                self.fit_item_55.setPen({'color': self.fit_color[value], 'width': 3})
        if '-70' in name:
            if self.fit_item_70 is not None:
                self.fit_item_70.setPen({'color': self.fit_color[value], 'width': 3})

    def clear_plots(self):
        """Remove every plotted item and fit curve, then reset the x-range."""
        # NOTE(review): item.scene() returns None if an item was never added
        # to (or already removed from) a scene -- confirm this cannot happen.
        for item in self.items + [self.fit_item_55, self.fit_item_70]:
            if item is None:
                continue
            item.scene().removeItem(item)
        self.items = []
        
        self.plots[-1].autoRange()
        self.plots[-1].setXRange(-5e-3, 10e-3)
        self.fit_item_70 = None
        self.fit_item_55 = None
예제 #31
0
    grid1.set_shape(2, 1)
    for i in range(n_responses):
        r = responses.responses[i]
        grid1[0, 0].plot(r.time_values, r.data)
        
        filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.)
        responses.responses[i] = filt
        
        dec = exp_deconvolve(r, 15e-3)
        baseline = np.median(dec.data[:100])
        r2 = bessel_filter(dec-baseline, 300.)
        grid1[1, 0].plot(r2.time_values, r2.data)
        
        deconv.append(r2)
    
    grid1.show()
    

    def measure_amp(trace, baseline=(6e-3, 8e-3), response=(13e-3, 17e-3)):
        """Peak of the response window minus the mean of the baseline window."""
        windows = [trace.time_slice(*w).data for w in (baseline, response)]
        return windows[1].max() - windows[0].mean()

    
    # Chop up responses into groups of varying size and plot the average
    # amplitude as measured from these chunks. We expect to see variance
    # decrease as a function of the number of responses being averaged together.
    x = []
    y_deconv1 = []
    y_deconv2 = []
    y_deconv3 = []
예제 #32
0
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Shows raw pre- and post-synaptic traces, per-pulse averaged responses
    with PSP fits, an event table, and controls for collecting event sets
    and fitting a release model across them.
    """
    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []

        # (pre, post, events, sweeps) tuple for the currently displayed analysis
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        # raw trace plots (pre on top, post below, x-axes linked)
        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        # one plot per pulse showing the averaged evoked response
        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # list of saved event sets; row 0 is always the live "current" set
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(
            self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        # signal-conditioning stages applied to the post-synaptic trace
        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(
            name='params',
            type='group',
            children=[{
                'name': 'pre',
                'type': 'list',
                'values': []
            }, {
                'name': 'post',
                'type': 'list',
                'values': []
            }, self.artifact_remover.params, self.baseline_remover.params,
                      self.filter.params, {
                          'name': 'time constant',
                          'type': 'float',
                          'suffix': 's',
                          'siPrefix': True,
                          'value': 10e-3,
                          'dec': True,
                          'minStep': 100e-6
                      }])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze and refresh all plots."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self):
        """Re-run the full analysis pipeline.

        Plots raw pre/post traces, detects stimulus pulses and evoked
        spikes, averages the post-synaptic response per pulse, fits a PSP
        to each average, and fills the event table.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(
                sweeps
        ) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        # label the raw plots according to each channel's clamp mode
        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot),
                               (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units),
                           bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from the edges of the command waveform
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(
                post_trace,
                list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data, one semi-transparent color per sweep
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot),
                                (post_filt, self.post_plot)]:
                plot.plot(trace.time_values,
                          trace.data,
                          pen=color,
                          antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            # mark detected spike times on the presynaptic plot
            dt = pre_trace.dt
            vticks = pg.VTickGroup(
                [x * dt for x in spike_inds if x is not None],
                yrange=[0.0, 0.2],
                pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        # NOTE(review): 'dt' below comes from the *last* sweep in the loop
        # above; this assumes all sweeps share one sample rate -- confirm.
        all_responses = []
        avg_responses = []
        fits = []
        fit = None

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses +
                                      1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0,
                            0].setLabels(left=("Averaged events (Channel %d)" %
                                               post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                max_len = int(40e-3 /
                              dt)  # don't take more than 40 ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            # baseline-subtract using the mode of the first 1 ms
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            # NOTE(review): 'start' is recomputed here but never used below,
            # and sp[i] will raise IndexError if any sweep recorded fewer
            # than npulses spike entries -- confirm pulse counts are equal
            # across sweeps.
            start = np.median(
                [sp[i]['rise_index']
                 for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t,
                                                   fit.eval(),
                                                   pen=fit_pen,
                                                   antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0, -1].plot(t,
                                        global_fit.eval(),
                                        pen=fit_pen,
                                        antialias=True)

        # display fit parameters in table
        events = []
        for i, f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                # final row summarizes the global average fit
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan),
                                    ('spike_stdev', np.nan)])
            else:
                spt = [
                    s[i]['peak_index'] * dt for s in spikes if s[i] is not None
                ]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)),
                                    ('spike_stdev', np.std(spt))])
            vals.update(
                OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP model to *data* sampled at times *t*.

        Initial parameter scales depend on *clamp_mode*: 'vc' (amps) or
        'ic' (volts).  Returns the fit result (exposes .eval(),
        .best_values and .params).
        """
        # response sign: compare the mean to the baseline mode (first 1 ms)
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            # current clamp: amplitudes in volts, slower kinetics
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Snapshot the current event set and append it to the saved list."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Delete the selected saved event set (row 0, 'current', is protected)."""
        if self.event_set_list.currentRow() == 0:
            return
        sel = self.event_set_list.takeItem(self.event_set_list.currentRow())
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set's fit parameters in the table."""
        sel = self.event_set_list.selectedItems()[0]
        if sel.text() == "current":
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit the synaptic release model to every saved event set, once per
        dynamics mechanism, and plot model output over the measured amplitudes.
        """
        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()

        n_sets = len(self.event_sets)
        self.fit_plot.set_shape(n_sets, 1)
        # remove any legend left over from a previous run
        # NOTE(review): 'l.scene' is not called here -- should this be
        # 'l.scene().removeItem(l)'? Confirm against pyqtgraph's API.
        l = self.fit_plot[0, 0].legend
        if l is not None:
            l.scene.removeItem(l)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        # convert each event set to (spike time, normalized amplitude)
        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            x -= x[0]
            x *= 1000  # s -> ms for the release model
            y /= y[0]  # normalize to the first response
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        # start with all mechanisms disabled; they are enabled cumulatively below
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0,
                               len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        # plot each model's output over the data
        # NOTE(review): 't' below is computed but never used -- confirm.
        for i, params in enumerate(fit_params):
            for j, spikes in enumerate(spike_sets):
                x, y = spikes
                t = np.linspace(0, x.max(), 1000)
                output = model.eval(x, params.values())
                y = output[:, 1]
                x = output[:, 0] / 1000.
                self.fit_plot[j, 0].plot(x,
                                         y,
                                         pen=(i, max_color),
                                         name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or refresh) the FitExplorer for a clicked PSP fit curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
# --- Analysis parameters ------------------------------------------------
holding = [-55, -75]  # holding potentials to analyze
freqs = [10, 20, 50, 100]  # induction-train frequencies
rec_t = [250, 500, 1000, 2000, 4000]  # recovery delays
sweep_threshold = 3  # minimum number of sweeps required
deconv = True  # also build deconvolution plot grids

# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)
qc_plot = pg.plot()
# One column per connection type.
# NOTE(review): 'connection_types' is assumed to be defined earlier in the
# module -- confirm.
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))  # one row per induction frequency
ind_plot.show()
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))  # one row per recovery delay
rec_plot.show()
if deconv is True:
    # parallel grids for the deconvolved responses
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5,len(connection_types))
    deconv_rec_plot.show()
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']  # one marker symbol per recovery delay
trace_color = (0, 0, 0, 5)  # nearly-transparent black for individual traces
예제 #34
0
def plot_response_averages(expt, show_baseline=False, **kwds):
    """Plot a grid of averaged evoked responses for every pre/post device pair.

    Parameters
    ----------
    expt :
        Experiment handle passed to MultiPatchExperimentAnalyzer.get().
    show_baseline : bool
        Reserved; the baseline-overlay code is currently disabled below.
    **kwds :
        Forwarded to get_evoked_response_matrix() (e.g. clamp_mode).

    Returns the PlotGrid containing the response matrix.
    """
    analyzer = MultiPatchExperimentAnalyzer.get(expt)
    devs = analyzer.list_devs()

    # First get average evoked responses for all pre/post pairs
    responses, rows, cols = analyzer.get_evoked_response_matrix(**kwds)

    # resize plot grid accordingly
    plots = PlotGrid()
    plots.set_shape(len(rows), len(cols))
    plots.show()

    # ranges[0] = (y mins, y maxs), ranges[1] = (x mins, x maxs) over all cells
    ranges = [([], []), ([], [])]
    points = []

    # Plot each matrix element with PSP fit
    for i, dev1 in enumerate(rows):
        for j, dev2 in enumerate(cols):
            # select plot and hide axes
            plt = plots[i, j]
            # Bug fix: hide the bottom axis on all but the last *row* of the
            # grid. The original compared against len(devs), which is wrong
            # whenever kwds filter the matrix down to fewer rows than devices.
            if i < len(rows) - 1:
                plt.getAxis('bottom').setVisible(False)
            if j > 0:
                plt.getAxis('left').setVisible(False)

            if dev1 == dev2:
                # diagonal cell: no self-connection to show
                plt.getAxis('bottom').setVisible(False)
                plt.getAxis('left').setVisible(False)
                continue

            # adjust axes / labels
            plt.setXLink(plots[0, 0])
            plt.setYLink(plots[0, 0])
            plt.addLine(x=10e-3, pen=0.3)
            plt.addLine(y=0, pen=0.3)
            plt.setLabels(bottom=(str(dev2), 's'))
            if kwds.get('clamp_mode', 'ic') == 'ic':
                plt.setLabels(left=('%s' % dev1, 'V'))
            else:
                plt.setLabels(left=('%s' % dev1, 'A'))

            avg_response = responses[(dev1, dev2)].bsub_mean()
            if avg_response is not None:
                avg_response.t0 = 0
                t = avg_response.time_values
                # low-pass filter for display
                y = bessel_filter(Trace(avg_response.data, dt=avg_response.dt), 2e3).data
                plt.plot(t, y, antialias=True)

                # fit!
                #fit = responses[(dev1, dev2)].fit_psp(yoffset=0, mask_stim_artifact=(abs(dev1-dev2) < 3))
                #lsnr = np.log(fit.snr)
                #lerr = np.log(fit.nrmse())

                #color = (
                    #np.clip(255 * (-lerr/3.), 0, 255),
                    #np.clip(50 * lsnr, 0, 255),
                    #np.clip(255 * (1+lerr/3.), 0, 255)
                #)

                #plt.plot(t, fit.best_fit, pen=color)
                ## plt.plot(t, fit.init_fit, pen='y')

                #points.append({'x': lerr, 'y': lsnr, 'brush': color})

                #if show_baseline:
                    ## plot baseline for reference
                    #bl = avg_response.meta['baseline'] - avg_response.meta['baseline_med']
                    #plt.plot(np.arange(len(bl)) * avg_response.dt, bl, pen=(0, 100, 0), antialias=True)

                # keep track of data range across all plots
                ranges[0][0].append(y.min())
                ranges[0][1].append(y.max())
                ranges[1][0].append(t[0])
                ranges[1][1].append(t[-1])

    # Robustness fix: only rescale when at least one cell contained data;
    # min()/max() of an empty list would raise ValueError.
    if ranges[0][0]:
        plots[0,0].setYRange(min(ranges[0][0]), max(ranges[0][1]))
        plots[0,0].setXRange(min(ranges[1][0]), max(ranges[1][1]))

    # scatter plot of SNR vs NRMSE (empty while the fit code above is disabled)
    plt = pg.plot()
    plt.setLabels(left='ln(SNR)', bottom='ln(NRMSE)')
    plt.plot([p['x'] for p in points], [p['y'] for p in points], pen=None, symbol='o', symbolBrush=[pg.mkBrush(p['brush']) for p in points])
    # show threshold line
    line = pg.InfiniteLine(pos=[0, 6], angle=180/np.pi * np.arctan(1))
    plt.addItem(line, ignoreBounds=True)

    return plots
예제 #35
0
    if connection == 'ee':
        connection_types = human_connections.keys()
    else:
        c_type = connection.split('-')
        connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]

sweep_threshold = 5
threshold = 0.03e-3
scale_offset = (-20, -20)
scale_anchor = (0.4, 1)
qc_plot = pg.plot()
grand_response = {}
feature_plot = None
synapse_plot = PlotGrid()
synapse_plot.set_shape(len(connection_types), 1)
synapse_plot.show()
for c in range(len(connection_types)):
    cre_type = (connection_types[c][0][1], connection_types[c][1][1])
    target_layer = (connection_types[c][0][0], connection_types[c][1][0])
    type = connection_types[c]
    expt_list = all_expts.select(cre_type=cre_type,
                                 target_layer=target_layer,
                                 calcium=calcium,
                                 age=age)
    color = color_palette[c]
    grand_response[type[0]] = {
        'trace': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'decay': [],
예제 #36
0
# Load test data
data = np.load('test_data/evoked_spikes/vc_evoked_spikes.npz')['arr_0']
dt = 20e-6  # sample period

# gaussian filtering constant
sigma = 20e-6 / dt

# Initialize Qt
pg.mkQApp()
pg.dbg()

# Create a window with a grid of plots (N rows, 1 column)
win = PlotGrid()
win.set_shape(data.shape[0], 1)
win.show()

# Loop over all 10 channels
for i in range(data.shape[0]):
    # select the data for this channel
    trace = data[i, :, 0]  # recorded trace
    stim = data[i, :, 1]  # stimulus waveform
    # NOTE(review): the third axis is assumed to be (trace, stimulus) -- confirm.

    # select the plot we will use for this trace
    plot = win[i, 0]

    # link all x-axes together
    plot.setXLink(win[0, 0])
    xaxis = plot.getAxis('bottom')
    # only the bottom-most plot gets an x-axis label
    if i == data.shape[0] - 1:
        xaxis.setLabel('Time', 's')
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    """Query first-pulse PSP features matching the given filters and plot the
    individual and grand-average traces for each filter combination.

    Parameters
    ----------
    organism : str | None
        Comma-separated species names (e.g. ``'mouse,human'``).
    conn_type : str | None
        Comma-separated connection types of the form
        ``'layer;cre-layer;cre'``; the literal layer string ``'None'`` means
        "layer unspecified".
    calcium : str | None
        Comma-separated calcium concentrations; prefix-matched against
        ``Experiment.acsf``.
    age : str | None
        Comma-separated age ranges of the form ``'lower-upper'``.
    sweep_thresh : int | None
        Minimum number of sweeps required for a pair to be included.
    fit_thresh :
        Unused; kept for interface compatibility.

    Returns
    -------
    (PlotGrid, PlotGrid)
        ``response_grid`` (one row per filter combination) and
        ``feature_grid`` (populated elsewhere; created here with 6 rows).
    """
    s = db.Session()

    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age
    }

    # Expand each comma-separated filter value into the cartesian product of
    # all requested combinations; each element of `selection` is one combo.
    # .items() replaces Py2-only .iteritems() (works on both Py2 and Py3).
    selection = [{}]
    for key, value in filters.items():
        if value is not None:
            temp_list = []
            value_list = value.split(',')
            for v in value_list:
                temp = [s1.copy() for s1 in selection]
                for t in temp:
                    t[key] = v
                temp_list = temp_list + temp
            selection = list(temp_list)

    # `selection` always holds at least the empty filter dict, so the grids
    # are always created and the return below cannot raise NameError.
    # (The previous `if len(selection) > 0` guard was dead code.)
    response_grid = PlotGrid()
    response_grid.set_shape(len(selection), 1)
    response_grid.show()
    feature_grid = PlotGrid()
    feature_grid.set_shape(6, 1)
    feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = db.aliased(db.Cell)
        post_cell = db.aliased(db.Cell)
        q_filter = []
        # Defaults so the title/print below are well-defined even when no
        # conn_type filter was requested (previously a NameError).
        pre_layer = pre_cre = post_layer = post_cre = None
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([pre_cell.cre_type == pre_cre, pre_cell.target_layer == pre_layer,
                             post_cell.cre_type == post_cre, post_cell.target_layer == post_layer])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            # prefix match, e.g. '2.0' matches '2.0mM ...'
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id == db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
            .join(db.Experiment, db.Experiment.id == db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id == db.Experiment.slice_id)

        for filter_arg in q_filter:
            q = q.filter(filter_arg)

        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = Trace(data=pair.avg_psp, sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TraceList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
            response_grid[i, 0].setTitle('layer %s, %s-> layer %s, %s; n_synapses = %d' %
                                      (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
예제 #38
0
def plot_features(organism=None,
                  conn_type=None,
                  calcium=None,
                  age=None,
                  sweep_thresh=None,
                  fit_thresh=None):
    """Query first-pulse PSP features matching the given filters and plot the
    individual and grand-average traces for each filter combination.

    Parameters
    ----------
    organism : str | None
        Comma-separated species names (e.g. ``'mouse,human'``).
    conn_type : str | None
        Comma-separated connection types of the form
        ``'layer;cre-layer;cre'``; the literal layer string ``'None'`` means
        "layer unspecified".
    calcium : str | None
        Comma-separated calcium concentrations; prefix-matched against
        ``Experiment.acsf``.
    age : str | None
        Comma-separated age ranges of the form ``'lower-upper'``.
    sweep_thresh : int | None
        Minimum number of sweeps required for a pair to be included.
    fit_thresh :
        Unused; kept for interface compatibility.

    Returns
    -------
    (PlotGrid, PlotGrid)
        ``response_grid`` (one row per filter combination) and
        ``feature_grid`` (populated elsewhere; created here with 6 rows).
    """
    # NOTE(review): lowercase db.session() here vs. db.Session() in the
    # sibling version of this function -- confirm which the db module exports.
    s = db.session()

    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age
    }

    # Expand each comma-separated filter value into the cartesian product of
    # all requested combinations; each element of `selection` is one combo.
    # .items() replaces Py2-only .iteritems() (works on both Py2 and Py3).
    selection = [{}]
    for key, value in filters.items():
        if value is not None:
            temp_list = []
            value_list = value.split(',')
            for v in value_list:
                temp = [s1.copy() for s1 in selection]
                for t in temp:
                    t[key] = v
                temp_list = temp_list + temp
            selection = list(temp_list)

    # `selection` always holds at least the empty filter dict, so the grids
    # are always created and the return below cannot raise NameError.
    # (The previous `if len(selection) > 0` guard was dead code.)
    response_grid = PlotGrid()
    response_grid.set_shape(len(selection), 1)
    response_grid.show()
    feature_grid = PlotGrid()
    feature_grid.set_shape(6, 1)
    feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = aliased(db.Cell)
        post_cell = aliased(db.Cell)
        q_filter = []
        # Defaults so the title/print below are well-defined even when no
        # conn_type filter was requested (previously a NameError).
        pre_layer = pre_cre = post_layer = post_cre = None
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([
                pre_cell.cre_type == pre_cre,
                pre_cell.target_layer == pre_layer,
                post_cell.cre_type == post_cre,
                post_cell.target_layer == post_layer
            ])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            # prefix match, e.g. '2.0' matches '2.0mM ...'
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(
                db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id == db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
            .join(db.Experiment, db.Experiment.id == db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id == db.Experiment.slice_id)

        for filter_arg in q_filter:
            q = q.filter(filter_arg)

        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = TSeries(data=pair.avg_psp,
                            sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TSeriesList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values,
                                     grand_trace.data,
                                     pen='b')
            response_grid[i, 0].setTitle(
                'layer %s, %s-> layer %s, %s; n_synapses = %d' %
                (pre_layer, pre_cre, post_layer, post_cre,
                 len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' %
                  (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
# plot options
parser.add_argument('--sweeps', action='store_true', default=False, dest='sweeps',
                    help='plot individual sweeps behind average')  # fixed typo: "behing" -> "behind"
parser.add_argument('--trains', action='store_true', default=False, dest='trains', help='plot 50Hz train and deconvolution')
# NOTE: dest contains a hyphen, so this option is only reachable through
# vars(args)['link-y-axis'] (as done below), not attribute access.
parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis', help='link all y-axis down a column')

args = vars(parser.parse_args(sys.argv[1:]))
plot_sweeps = args['sweeps']
plot_trains = args['trains']
link_y_axis = args['link-y-axis']
# NOTE(review): expt_cache is a machine-specific path and unused in this
# visible span -- presumably consumed further down the script.
expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl'
all_expts = cached_experiments()

# One summary row per connection type (len(dict) == len(dict.keys())).
test = PlotGrid()
test.set_shape(len(connection_types), 1)
test.show()

# Main grid: 2 columns (first pulse, average) plus an optional third column
# of 50 Hz trains with exponential deconvolution.
grid = PlotGrid()
if plot_trains is True:
    grid.set_shape(len(connection_types), 3)
    grid[0, 1].setTitle(title='50 Hz Train')
    grid[0, 2].setTitle(title='Exponential Deconvolution')
    tau = 15e-3  # deconvolution time constant (s)
    lp = 1000    # low-pass cutoff -- presumably Hz; TODO confirm
else:
    grid.set_shape(len(connection_types), 2)
grid.show()
row = 0
grid[0, 0].setTitle(title='First Pulse')

# Running maxima used later to scale/link y-axes across rows.
maxYpulse = []
maxYtrain = []