Esempio n. 1
0
    def _get_tserieslist(self, ts_name, align, bsub):
        """Collect one time series per pulse response into a TSeriesList.

        Parameters
        ----------
        ts_name : str
            Attribute name read from each pulse response in ``self.prs``.
        align : None, 'spike', or 'pulse'
            'spike' shifts each series so ``stim_pulse.first_spike_time`` is
            at t=0; responses whose spike time is None are dropped. 'pulse'
            shifts so the pulse onset is at t=0. None leaves timing as-is.
        bsub : bool
            Baseline-subtract only when this is exactly ``True``. The baseline
            is ``float_mode`` of the (up to) 5 ms preceding the pulse onset,
            or the first sample when that window is empty.

        Raises
        ------
        ValueError
            If *align* is not an accepted value (raised while processing the
            first pulse response).
        """
        collected = []
        for resp in self.prs:
            series = getattr(resp, ts_name)
            onset = resp.stim_pulse.onset_time

            if bsub is True:
                window = series.time_slice(max(series.t0, onset - 5e-3), onset).data
                base = series.data[0] if len(window) == 0 else float_mode(window)
                series = series - base

            if align is None:
                collected.append(series)
                continue

            if align == 'spike':
                shift = resp.stim_pulse.first_spike_time
                if shift is None:
                    # no known spike time: this response cannot be aligned
                    continue
            elif align == 'pulse':
                shift = onset
            else:
                raise ValueError(
                    "align must be None, 'spike', or 'pulse'.")
            collected.append(series.copy(t0=series.t0 - shift))

        return TSeriesList(collected)
Esempio n. 2
0
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    """Fit a PSP model to the averaged first-pulse responses of *pair*.

    Parameters
    ----------
    pair :
        Object providing ``synapse_prediction.ic_fit_xoffset``,
        ``synapse_prediction.synapse_type`` ('ex' or 'in') and, for the error
        message, ``expt_id`` / ``pre_cell_id`` / ``post_cell_id``.
    pulse_responses : list
        Traces averaged with ``TSeriesList(...).mean()``. The fit weighting
        below treats 10-12 ms as the stimulus artifact, i.e. it assumes the
        stimulus lands ~10 ms into each trace.
    pulse_response_amps : sequence of float
        Per-response amplitudes used for the coefficient of variation.

    Returns
    -------
    dict
        ``ic_fit_amp``, ``ic_fit_latency`` (fitted xoffset minus 10 ms),
        ``ic_fit_rise_time``, ``ic_fit_decay_tau``, ``ic_amp_cv``
        (std / mean of *pulse_response_amps*) and ``avg_psp`` (the
        baseline-subtracted average's data array).

    Raises
    ------
    Exception
        If the predicted synapse type is neither 'ex' nor 'in'.
    """
    avg_psp = TSeriesList(pulse_responses).mean()
    dt = avg_psp.dt
    # Sample indices of the fixed time landmarks used below.
    i_10ms = int(10e-3 / dt)
    i_12ms = int(12e-3 / dt)
    i_19ms = int(19e-3 / dt)

    # Baseline is the mode of the first 10 ms (before the stimulus artifact).
    avg_psp_baseline = float_mode(avg_psp.data[:i_10ms])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)

    lower_bound = -float('inf')
    upper_bound = float('inf')

    xoffset = pair.synapse_prediction.ic_fit_xoffset
    if xoffset is None:
        # Default initial guess: 14 ms, a few ms after the stimulus. This was
        # previously written as ``14 * 10e-3`` (= 140 ms), far outside the
        # 10-19 ms window the fit weights emphasize.
        xoffset = 14e-3

    synapse_type = pair.synapse_prediction.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception(
            'Synapse type is not defined, reconsider fitting this pair %s %d->%d'
            % (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    # Fit weights: 10 everywhere, 0 over the stimulus artifact (10-12 ms),
    # 30 over the steep PSP rise (12-19 ms).
    weight = np.full(len(avg_psp.data), 10.)
    weight[i_10ms:i_12ms] = 0.
    weight[i_12ms:i_19ms] = 30.

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    best = psp_fits.best_values
    # TODO: add 'ic_fit_NRMSE' once an nrmse value is available from psp_fits.
    return {
        'ic_fit_amp': best['amp'],
        'ic_fit_latency': best['xoffset'] - 10e-3,
        'ic_fit_rise_time': best['rise_time'],
        'ic_fit_decay_tau': best['decay_tau'],
        'ic_amp_cv': amp_cv,
        'avg_psp': avg_psp_bsub.data,
    }
Esempio n. 3
0
 def clicked(sp, pts):
     """Scatter-click handler: print the clicked point's details and plot its traces.

     The first clicked point's data dict must provide 'rise_time', 'amp',
     'prediction', 'confidence', 'results' and 'traces'. Each trace is
     baseline-subtracted using the median of its 0-1 ms slice and drawn in a
     new window, with the average of those traces drawn in green on top.
     """
     info = pts[0].data()
     print("-----------------------\nclicked:", info['rise_time'],
           info['amp'], info['prediction'], info['confidence'])
     for res in info['results']:
         print({feat: res[feat] for feat in classifier.features})
     raw_traces = info['traces']
     plot_win = pg.plot()
     subtracted = []
     for tr in raw_traces:
         base = np.median(tr.time_slice(0, 1e-3).data)
         subtracted.append(tr.copy(data=tr.data - base))
     for tr in subtracted:
         plot_win.plot(tr.time_values, tr.data, pen=(0, 0, 0, 50))
     avg = TSeriesList(subtracted).mean()
     plot_win.plot(avg.time_values, avg.data, pen='g')
Esempio n. 4
0
def get_average_pulse_response(pair, desired_clamp='ic'):
    """Average the QC-passing first-pulse responses of *pair* and measure the PSP.

    Inputs
    ------
    pair: aisynphys.database.database.Pair object

    desired_clamp: string
        Selects current- or voltage-clamp sweeps:
            'ic': current clamp
            'vc': voltage clamp

    Returns
    -------
    A 7-tuple; every element is None when no sweeps pass QC.
    pulse_responses: TSeriesList
        traces where the start of each trace is 10 ms before the spike
    pulse_ids: list of ints
        pulse ids of *pulse_responses*
    psp_amps_measured: list of floats
        amplitude of *pulse_responses* from the *pulse_response* table
    freq: list of floats
        the stimulation frequency corresponding to the *pulse_responses*
    avg_psp: TSeries
        average of the pulse_responses
    measured_relative_amp: float
        measured amplitude relative to baseline
    measured_baseline: float
        value of baseline
    """
    pulse_responses, pulse_ids, psp_amps_measured, freq = \
        extract_first_pulse_info_from_Pair_object(pair, desired_clamp=desired_clamp)

    # Nothing passed QC: every output is None.
    if len(pulse_responses) == 0:
        return (None,) * 7

    avg_psp = TSeriesList(pulse_responses).mean()
    dt = avg_psp.dt
    # Baseline ends 1 ms before the spike; the response window starts 0.5 ms after it.
    baseline_window = [0, int((time_before_spike - 1.e-3) / dt)]
    response_window = [int((time_before_spike + .5e-3) / dt), -1]
    measured_relative_amp, measured_baseline = measure_amp(
        avg_psp.data, baseline_window, response_window)

    return (pulse_responses, pulse_ids, psp_amps_measured, freq, avg_psp,
            measured_relative_amp, measured_baseline)
Esempio n. 5
0
 def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
     """Return a mean line and a jittered scatter of *field_name* values for one element.

     Parameters
     ----------
     pre_class, post_class :
         Not used by this method.
     element :
         ``element[field_name]`` must provide ``.mean()`` and ``.items()``
         yielding ``(pair, value)`` pairs; presumably a pandas object indexed
         by pair objects (TODO confirm). Each pair needs an ``id`` attribute.
     field_name : str
         Name of the value column to plot.
     color :
         Color for the mean line. If it is a tuple, alpha 150 is appended for
         the scatter brush; any other color spec is passed through unchanged.
     trace_plt :
         Currently ignored: trace plotting is forced off at the top of the
         method.

     Returns
     -------
     (pg.InfiniteLine, pg.ScatterPlotItem)
         A line at the column mean, and one scatter point per non-NaN value.
         Each point is stored in ``self.pair_items[pair.id]`` and clicks are
         routed to ``self.scatter_plot_clicked``.
     """
     # NOTE(review): the caller's trace_plt is deliberately discarded, so the
     # trace-plotting branch below does not run. Kept to preserve behavior.
     trace_plt = None
     val = element[field_name].mean()
     line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
     baseline_window = int(db.default_sample_rate * 5e-3)
     values = []
     traces = []  # accumulated across all pairs for the grand average
     point_data = []
     # .items() replaces .iteritems(), which pandas 2.0 removed.
     for pair, value in element[field_name].items():
         if np.isnan(value):
             continue
         if trace_plt is not None:
             rsf = pair.resting_state_fit
             if rsf is not None:
                 trace = rsf.ic_avg_data
                 start_time = rsf.ic_avg_data_start_time
                 latency = pair.synapse.latency
                 if latency is not None and start_time is not None:
                     xoffset = start_time - latency
                     trace = format_trace(trace, baseline_window, x_offset=xoffset, align='psp')
                     trace_plt.plot(trace.time_values, trace.data)
                     traces.append(trace)
         values.append(value)
         point_data.append(pair)

     # Build the scatter once, after all values are collected. It used to be
     # rebuilt on every iteration, and it stayed None (crashing below) when no
     # value survived the NaN filter.
     if values:
         y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
     else:
         y_values = np.zeros(0)
     # Only tuples can take an appended alpha; e.g. the default 'g' cannot.
     brush = color + (150,) if isinstance(color, tuple) else color
     scatter = pg.ScatterPlotItem(symbol='o', brush=brush, pen='w', size=12)
     scatter.setData(values, y_values + 10., data=point_data)
     for point in scatter.points():
         self.pair_items[point.data().id] = [point, color]
     scatter.sigClicked.connect(self.scatter_plot_clicked)
     if len(traces) > 0:
         grand_trace = TSeriesList(traces).mean()
         trace_plt.plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
         units = 'V' if field_name.startswith('ic') else 'A'
         trace_plt.setXRange(0, 20e-3)
         trace_plt.setLabels(left=('', units), bottom=('Time from stimulus', 's'))
     return line, scatter
def trace_avg(response_list):
    # Documentation is deliberately kept as comments rather than a docstring,
    # to discourage reuse: this function MUTATES its input by setting each
    # trace's t0 to 0.
    #
    # Parameters
    #   response_list : list of neuroanalysis.data.TSeriesView objects
    # Returns
    #   neuroanalysis.data.TSeries -- the average of the inputs passed through
    #   bsub() (baseline subtraction)
    for item in response_list:
        item.t0 = 0  # put every trace on a common time base for TSeriesList().mean()
    return bsub(TSeriesList(response_list).mean())
 def get_tseries(self,
                 series,
                 bsub=True,
                 align='stim',
                 bsub_window=(-3e-3, 0)):
     """Build a TSeriesList from one recording of every entry in ``self.srs``.

     Parameters
     ----------
     series : str
         "stim", "pre", or "post"; selects each entry's ``<series>_tseries``.
     bsub : bool
         If truthy, subtract the median of the samples inside *bsub_window*.
     align : str or None
         'stim' shifts t0 so the stimulus pulse onset is at 0; 'pre' shifts
         so the first spike's ``max_dvdt_time`` is at 0. A falsy alignment
         time (e.g. None) is treated as 0. 'post' is not implemented; None
         skips alignment.
     bsub_window : (float, float)
         Baseline window (start, stop) relative to the pulse onset.

     Raises
     ------
     AssertionError
         If *series* is not an accepted name.
     NotImplementedError
         If *align* is 'post' (raised while processing the first entry).
     ValueError
         If *align* is any other unrecognized value (raised while processing
         the first entry).
     """
     assert series in ('stim', 'pre', 'post'), "series must be one of 'stim', 'pre', or 'post'"
     attr_name = series + '_tseries'
     collected = []
     for resp in self.srs:
         trace = getattr(resp, attr_name)
         if bsub:
             onset = resp.stim_pulse.onset_time
             window = trace.time_slice(onset + bsub_window[0], onset + bsub_window[1])
             trace = trace - np.median(window.data)
         if align is None:
             collected.append(trace)
             continue
         if align == 'stim':
             shift = resp.stim_pulse.onset_time
         elif align == 'pre':
             shift = resp.stim_pulse.spikes[0].max_dvdt_time
         elif align == 'post':
             raise NotImplementedError()
         else:
             raise ValueError("invalid time alignment mode %r" % align)
         collected.append(trace.copy(t0=trace.t0 - (shift or 0)))
     return TSeriesList(collected)
Esempio n. 8
0
                                train_response,
                                freqs,
                                holding,
                                thresh=sweep_threshold,
                                ind_dict=grand_induction,
                                offset_dict=offset_ind)
                            grand_recovery, offset_rec = recovery_summary(
                                train_response,
                                t_rec,
                                holding,
                                thresh=sweep_threshold,
                                rec_dict=grand_recovery,
                                offset_dict=offset_rec)

        if len(grand_pulse_response) > 0:
            grand_pulse_trace = TSeriesList(grand_pulse_response).mean()
            p2 = trace_plot(grand_pulse_trace,
                            color=avg_color,
                            plot=p2,
                            x_range=[0, 27e-3],
                            name=('n = %d' % len(grand_pulse_response)))
            if len(grand_induction) > 0:
                for f, freq in enumerate(freqs):
                    if freq in grand_induction:
                        offset = offset_ind[freq]
                        ind_pass_qc = train_qc(grand_induction[freq],
                                               offset,
                                               amp=amp_thresh,
                                               sign=sign)
                        n = len(ind_pass_qc[0])
                        if n > 0:
Esempio n. 9
0
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below). NOTE: entries of this dict are replaced
        in place by their QC-filtered versions.
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory.

    Input must have the following structure::

        amps = {
            'ic': recs,
            'vc': recs,
        }

    Where each *recs* must be a structured array containing fields as returned
    by get_amps().

    The overall strategy here is:

    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency

    Returns
    -------
    dict or None
        Field values for a new DB record, or None if no records pass QC.
    """
    # Filter by QC (rebinding existing keys while iterating is safe: size is unchanged)
    for k, v in amps.items():
        amps[k] = v[v['qc_pass'].astype(bool)]

    # See if any data remains. Iterate the record arrays, not the dict keys:
    # len('ic') is 2, so checking the keys could never detect empty data.
    if all(len(a) == 0 for a in amps.values()):
        return None

    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # QC field per (clamp mode, deflection polarity): a positive deflection is
    # checked with the excitatory QC in current clamp but the inhibitory QC in
    # voltage clamp.
    qc_fields = {
        'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
        'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
    }

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_amps = amps[clamp_mode]
        if len(clamp_mode_amps) == 0:
            continue
        for polarity in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_mask = clamp_mode_amps[qc_fields[clamp_mode][polarity]].astype(bool)
            recs = clamp_mode_amps[qc_mask]
            qc_amps[polarity, clamp_mode] = recs
            if len(recs) == 0:
                continue

            bg = recs['baseline_' + polarity + '_dec_amp']
            fg = recs[polarity + '_dec_amp']
            ks_pvals[polarity, clamp_mode] = scipy.stats.ks_2samp(fg, bg).pvalue
            amp_diffs[polarity, clamp_mode] = np.mean(fg) - np.mean(bg)

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory:
        # accumulate evidence from the ks p-value of each polarity/clamp mode,
        # signed by the direction of the deflection.
        is_exc = 0
        for polarity in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((polarity, mode))
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                weight = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[polarity, mode] > 0 else -1
                if mode == 'vc':
                    dif_sign *= -1
                is_exc += dif_sign * weight
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        polarity = signs[clamp_mode]
        fg = qc_amps.get((polarity, clamp_mode))
        if fg is None or len(fg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(fg['baseline_crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'),
                           ('latency', 'dec_latency')]:
            f = fg[polarity + '_' + field]
            b = fg['baseline_' + polarity + '_' + field]
            prefix = clamp_mode + '_' + val
            fields[prefix + '_mean'] = np.mean(f)
            fields[prefix + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg, normalized with norm_pvalue
            # so they are easier to plot and use as classifier inputs
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[prefix + '_ttest'] = norm_pvalue(tt_pval)
            fields[prefix + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit

        # collect all fg traces, time-aligned to the presynaptic spike
        fg_traces = TSeriesList()
        for rec in fg:
            slope_time = rec['max_slope_time']
            # check None first: np.isfinite(None) raises TypeError
            if slope_time is None or not np.isfinite(slope_time):
                continue
            t0 = rec['response_start_time'] - slope_time
            fg_traces.append(TSeries(rec['data'],
                                     sample_rate=db.default_sample_rate,
                                     t0=t0))

        if len(fg_traces) == 0:
            continue

        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        fit_sign = {'pos': 1, 'neg': -1}[polarity]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub,
                          clamp_mode=clamp_mode,
                          sign=fit_sign,
                          search_window=[0, 6e-3])
            for param, val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = val
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except Exception:
            # best-effort: report the failed fit and move on to the next clamp mode
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())
            continue

    return fields
Esempio n. 10
0
    def add_connection_plots(i, name, timestamp, pre_id, post_id):
        """Add one row of Figure 3 panels for a single connection.

        Row *i* of the ``win`` layout receives: (column 1) up to 100 example
        response traces with their average; (column 2) histograms of
        background vs. foreground ``pos_dec_amp`` values; (column 3) the mean
        classifier confidence vs. simulated PSP amplitude for several rise
        times. Each plotted series is also written to ``csv_file`` with
        ``write_csv``. Simulation results are cached in 'fig_3_cache.pkl'.

        Parameters
        ----------
        i : int
            Layout row in ``win``.
        name : str
            Label drawn on the trace plot and used in the CSV headers.
        timestamp, pre_id, post_id :
            Locate the connection in ``filtered`` (acq_timestamp matched
            within 1 unit -- presumably seconds, TODO confirm).

        Relies on names from the enclosing scope: ``session``, ``win``,
        ``filtered``, ``trace_plots``, ``hist_plots``, ``scatter_plot``,
        ``background``, ``signal``, ``classifier``, ``csv_file``.
        """
        global session, win, filtered
        p = pg.debug.Profiler(disabled=True, delayed=False)
        trace_plot = win.addPlot(i, 1)
        trace_plots.append(trace_plot)
        trace_plot.setYRange(-1.4e-3, 2.1e-3)

        hist_plot = win.addPlot(i, 2)
        hist_plots.append(hist_plot)
        limit_plot = win.addPlot(i, 3)
        limit_plot.addLegend()
        limit_plot.setLogMode(True, False)
        limit_plot.addLine(y=classifier.prob_threshold)

        # Find this connection in the pair list
        idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1)
                          & (filtered['pre_cell_id'] == pre_id)
                          & (filtered['post_cell_id'] == post_id))
        if idx.size == 0:
            print("not in filtered connections")
            return
        idx = idx[0, 0]
        p()

        # Mark the point in scatter plot
        scatter_plot.plot([background[idx]], [signal[idx]],
                          pen='k',
                          symbol='o',
                          size=10,
                          symbolBrush='r',
                          symbolPen=None)

        # Configure the trace plot: link its axes to the first row, hide the
        # axes, mark t=0 and draw scale bars (2e-3 wide, 100e-6 tall).
        plt = trace_plot
        plt.setXLink(trace_plots[0])
        plt.setYLink(trace_plots[0])
        plt.setXRange(-10e-3, 17e-3, padding=0)
        plt.hideAxis('left')
        plt.hideAxis('bottom')
        plt.addLine(x=0)
        plt.setDownsampling(auto=True, mode='peak')
        plt.setClipToView(True)
        hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
        hbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(hbar)
        vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
        vbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(vbar)

        hist_plot.setXLink(hist_plots[0])

        pair = session.query(
            db.Pair).filter(db.Pair.id == filtered[idx]['pair_id']).all()[0]
        p()
        amps = strength_analysis.get_amps(session, pair)
        p()
        base_amps = strength_analysis.get_baseline_amps(session,
                                                        pair,
                                                        amps=amps,
                                                        clamp_mode='ic')
        p()

        q = strength_analysis.response_query(session)
        p()
        q = q.join(strength_analysis.PulseResponseStrength)
        q = q.filter(strength_analysis.PulseResponseStrength.id.in_(
            amps['id']))
        q = q.join(db.MultiPatchProbe)
        q = q.filter(db.MultiPatchProbe.induction_frequency < 100)

        fg_recs = q.all()
        p()

        # Plot up to 100 example traces, aligned to the spike and offset by
        # the median of the +/-0.5 ms window around it.
        traces = []
        for trace_n, rec in enumerate(fg_recs[:100]):
            result = strength_analysis.analyze_response_strength(
                rec,
                source='pulse_response',
                lpf=True,
                lowpass=2000,
                remove_artifacts=False,
                bsub=True)
            trace = result['raw_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            traces.append(trace)
            trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))
            write_csv(
                csv_file, trace,
                "Figure 3B; {name}; trace {trace_n}".format(name=name,
                                                            trace_n=trace_n))

        # plot average trace
        mean = TSeriesList(traces).mean()
        trace_plot.plot(mean.time_values,
                        mean.data,
                        pen={
                            'color': 'g',
                            'width': 2
                        },
                        shadowPen={
                            'color': 'k',
                            'width': 3
                        },
                        antialias=True)
        write_csv(csv_file, mean,
                  "Figure 3B; {name}; average".format(name=name))

        # add label
        label = pg.LabelItem(name)
        label.setParentItem(trace_plot)

        p("analyze_response_strength")

        # Histograms of background vs. foreground deconvolved amplitudes,
        # truncated to the same number of samples.
        bins = np.arange(-0.001, 0.015, 0.0005)
        field = 'pos_dec_amp'
        n = min(len(amps), len(base_amps))
        hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins,
                       hist_y,
                       stepMode=True,
                       pen=None,
                       brush=(200, 0, 0, 150),
                       fillLevel=0)
        write_csv(
            csv_file, hist_bins,
            "Figure 3C; {name}; background noise amplitude distribution bin edges (V)"
            .format(name=name))
        write_csv(
            csv_file, hist_y,
            "Figure 3C; {name}; background noise amplitude distribution counts per bin"
            .format(name=name))

        hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins,
                       hist_y,
                       stepMode=True,
                       pen='k',
                       brush=(0, 150, 150, 100),
                       fillLevel=0)
        write_csv(
            csv_file, hist_bins,
            "Figure 3C; {name}; PSP amplitude distribution bin edges (V)".
            format(name=name))
        write_csv(
            csv_file, hist_y,
            "Figure 3C; {name}; PSP amplitude distribution counts per bin".
            format(name=name))
        p()

        pg.QtGui.QApplication.processEvents()

        # Plot detectability analysis
        q = strength_analysis.baseline_query(session)
        q = q.join(strength_analysis.BaselineResponseStrength)
        q = q.filter(
            strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
        bg_recs = q.all()

        def clicked(sp, pts):
            # Print the clicked point's details and plot its baseline-subtracted traces.
            data = pts[0].data()
            print("-----------------------\nclicked:", data['rise_time'],
                  data['amp'], data['prediction'], data['confidence'])
            for r in data['results']:
                print({k: r[k] for k in classifier.features})
            traces = data['traces']
            plt = pg.plot()
            bsub = [
                t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data))
                for t in traces
            ]
            for t in bsub:
                plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
            mean = TSeriesList(bsub).mean()
            plt.plot(mean.time_values, mean.data, pen='g')

        # measure background connection strength
        bg_results = [
            strength_analysis.analyze_response_strength(rec, 'baseline')
            for rec in bg_recs
        ]
        bg_results = strength_analysis.str_analysis_result_table(
            bg_results, bg_recs)

        # for this example, we use background data to simulate foreground
        # (but this will be biased due to lack of crosstalk in background data)
        fg_recs = bg_recs

        # now measure foreground simulated under different conditions
        amps = 2e-6 * 2**np.arange(9)
        amps[0] = 0
        rtimes = [1e-3, 2e-3, 4e-3, 6e-3]
        results = np.empty((len(amps), len(rtimes)),
                           dtype=[('results', object), ('predictions', object),
                                  ('confidence', object), ('traces', object),
                                  ('rise_time', float), ('amp', float)])
        print("  Simulating synaptic events..")

        # NOTE: pickle.load can execute arbitrary code; only load a cache file
        # that this script itself wrote.
        cachefile = 'fig_3_cache.pkl'
        if os.path.exists(cachefile):
            with open(cachefile, 'rb') as fh:
                cache = pickle.load(fh)
        else:
            cache = {}
        pair_key = (timestamp, pre_id, post_id)
        pair_cache = cache.setdefault(pair_key, {})

        for j, rtime in enumerate(rtimes):
            new_results = False
            for amp_idx, amp in enumerate(amps):
                # '\r' + end='' rewrites the progress line in place
                print(
                    "---------------------------------------    %d/%d  %d/%d      \r"
                    % (amp_idx, len(amps), j, len(rtimes)), end='')
                result = pair_cache.get((rtime, amp))
                if result is None:
                    result = strength_analysis.simulate_connection(
                        fg_recs, bg_results, classifier, amp, rtime)
                    pair_cache[rtime, amp] = result
                    new_results = True

                for k, v in result.items():
                    results[amp_idx, j][k] = v

            x = amps
            y = [np.mean(conf) for conf in results[:, j]['confidence']]
            c = limit_plot.plot(x,
                                y,
                                pen=pg.intColor(j,
                                                len(rtimes) * 1.3,
                                                maxValue=150),
                                symbol='o',
                                antialias=True,
                                name="%dus" % (rtime * 1e6),
                                data=results[:, j],
                                symbolSize=4)
            write_csv(
                csv_file, x,
                "Figure 3D; {name}; {rise_time:0.3g} ms rise time; simulated PSP amplitude (V)"
                .format(name=name, rise_time=rtime * 1000))
            write_csv(
                csv_file, y,
                "Figure 3D; {name}; {rise_time:0.3g} ms rise time; classifier decision probability"
                .format(name=name, rise_time=rtime * 1000))
            c.scatter.sigClicked.connect(clicked)
            pg.QtGui.QApplication.processEvents()

            if new_results:
                # close (and flush) the cache file deterministically
                with open(cachefile, 'wb') as fh:
                    pickle.dump(cache, fh)

        pg.QtGui.QApplication.processEvents()
Esempio n. 11
0
 def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
     """Plot per-pair average responses and a scatter of ``element[field_name]``.

     For each pair in ``element[field_name]`` that has a synapse, a non-NaN
     value and a resting-state fit with known timing, the fit's average
     response is aligned with ``format_trace`` and drawn in ``trace_plt[1]``.
     For the 'Latency' field the voltage-clamp average is additionally drawn
     in ``trace_plt[0]``. Grand averages of the drawn traces are overlaid.

     Args:
         pre_class, post_class: only used to build the grand-average legend name.
         element: table indexed by pair with a ``field_name`` column
             (presumably a pandas DataFrame -- TODO confirm).
         field_name: column to summarize. Names starting with 'PSC' select
             voltage-clamp data ('A' units); all others select current-clamp
             data ('V' units).
         color: color of the mean line, scatter points and grand averages.
             The scatter brush is built as ``color + (150,)``, so in practice
             this must be an RGB tuple (the string default would fail there).
         trace_plt: pair of plot items; index 1 receives the main traces,
             index 0 the voltage-clamp traces for 'Latency'.

     Returns:
         (line, scatter): an InfiniteLine at the mean of ``field_name`` and
         the ScatterPlotItem of the per-pair values that got a main trace.

     Side effects:
         Stores ``[trace_itemA, trace_itemB]`` in ``self.pair_items[pair.id]``
         for every pair passing the synapse/NaN checks; pairs with a scatter
         point get ``[point, color]`` appended.
     """
     val = element[field_name].mean()
     line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
     is_psc = field_name.startswith('PSC')
     # Defined up front: the voltage-clamp block at the end also labels with
     # `units`, and it can run even when no main (tracesA) trace was drawn.
     units = 'A' if is_psc else 'V'
     values = []
     tracesA = []
     tracesB = []
     point_data = []
     # Series.items() works on every pandas version; iteritems() was removed in pandas 2.0
     for pair, value in element[field_name].items():
         latency = self.results.loc[pair]['Latency']
         trace_itemA = None
         trace_itemB = None
         if pair.has_synapse is not True:
             continue
         if np.isnan(value):
             continue
         rsf = pair.resting_state_fit
         if rsf is not None:
             # NOTE: an NRMSE quality cut (skip when nrmse is None or > 0.8)
             # is intentionally disabled for both clamp modes.
             data = rsf.vc_avg_data if is_psc else rsf.ic_avg_data
             traceA = TSeries(data=data, sample_rate=db.default_sample_rate)
             if is_psc:
                 traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
             start_time = rsf.vc_avg_data_start_time if is_psc else rsf.ic_avg_data_start_time
             if latency is not None and start_time is not None:
                 if field_name == 'Latency':
                     xoffset = start_time + latency
                 else:
                     xoffset = start_time - latency
                 # baseline window of width 1e-3 ending at |xoffset|
                 baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                 traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
                 trace_itemA = trace_plt[1].plot(traceA.time_values, traceA.data)
                 trace_itemA.pair = pair
                 trace_itemA.curve.setClickable(True)
                 trace_itemA.sigClicked.connect(self.trace_plot_clicked)
                 tracesA.append(traceA)
             if field_name == 'Latency' and rsf.vc_nrmse is not None:
                 traceB = TSeries(data=rsf.vc_avg_data, sample_rate=db.default_sample_rate)
                 traceB = bessel_filter(traceB, 5000, btype='low', bidir=True)
                 start_time = rsf.vc_avg_data_start_time
                 if latency is not None and start_time is not None:
                     xoffset = start_time + latency
                     baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                     traceB = format_trace(traceB, baseline_window, x_offset=xoffset, align='psp')
                     trace_itemB = trace_plt[0].plot(traceB.time_values, traceB.data)
                     trace_itemB.pair = pair
                     trace_itemB.curve.setClickable(True)
                     trace_itemB.sigClicked.connect(self.trace_plot_clicked)
                     tracesB.append(traceB)
         self.pair_items[pair.id] = [trace_itemA, trace_itemB]
         # only pairs whose main trace was drawn get a scatter point
         if trace_itemA is not None:
             values.append(value)
             point_data.append(pair)
     y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
     scatter = pg.ScatterPlotItem(symbol='o', brush=(color + (150,)), pen='w', size=12)
     scatter.setData(values, y_values + 10., data=point_data)
     for point in scatter.points():
         pair_id = point.data().id
         self.pair_items[pair_id].extend([point, color])
     scatter.sigClicked.connect(self.scatter_plot_clicked)
     if len(tracesA) > 0:
         if field_name == 'Latency':
             spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
             trace_plt[0].addItem(spike_line)
             x_label = 'Time from presynaptic spike'
         else:
             x_label = 'Response Onset'
         grand_trace = TSeriesList(tracesA).mean()
         name = ('%s->%s, n=%d' % (pre_class, post_class, len(tracesA)))
         trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
         title = 'Voltage Clamp' if is_psc else 'Current Clamp'
         trace_plt[1].setXRange(-5e-3, 20e-3)
         trace_plt[1].setLabels(left=('', units), bottom=(x_label, 's'))
         trace_plt[1].setTitle(title)
     if len(tracesB) > 0:
         trace_plt[1].setLabels(right=('', units))
         trace_plt[1].hideAxis('left')
         spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
         trace_plt[0].addItem(spike_line)
         grand_trace = TSeriesList(tracesB).mean()
         trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
         trace_plt[0].setXRange(-5e-3, 20e-3)
         trace_plt[0].setLabels(left=('', 'A'), bottom=('Time from presynaptic spike', 's'))
         trace_plt[0].setTitle('Voltage Clamp')
     return line, scatter
Esempio n. 12
0
                    offset_dict=pulse_offset_rec,
                    uid=(expt.uid, pre, post))
    for f, freq in enumerate(freqs):
        if freq not in induction_grand.keys():
            print("%d Hz not represented in data set for %s" % (freq, c_type))
            continue
        ind_offsets = pulse_offset_ind[freq]
        qc_plot.clear()
        ind_pass_qc = train_qc(induction_grand[freq],
                               ind_offsets,
                               amp=qc_params[1][c],
                               sign=qc_params[0],
                               plot=qc_plot)
        n_synapses = len(ind_pass_qc[0])
        if n_synapses > 0:
            induction_grand_trace = TSeriesList(ind_pass_qc[0]).mean()
            ind_rec_grand_trace = TSeriesList(ind_pass_qc[1]).mean()
            ind_amp = train_amp(ind_pass_qc, ind_offsets, '+')
            ind_amp_grand = np.nanmean(ind_amp, 0)

            if f == 0:
                ind_plot[f, c].setTitle(connection_types[c])
                type = pg.LabelItem('%s -> %s' % connection_types[c])
                type.setParentItem(summary_plot[c, 0])
                type.setPos(50, 0)
            if c == 0:
                label = pg.LabelItem('%d Hz Induction' % freq)
                label.setParentItem(ind_plot[f, c].vb)
                label.setPos(50, 0)
                summary_plot[c, 0].setTitle('Induction')
            ind_plot[f, c].addLegend()
Esempio n. 13
0
    def plot_prd_ids(self,
                     ids,
                     source,
                     pen=None,
                     trace_list=None,
                     avg=False,
                     qc_filter=None):
        """Plot raw or deconvolved PulseResponse data, given IDs of records in
        a db.PulseResponseStrength table.

        Args:
            ids: record IDs, resolved with ``self.get_pulse_recs(ids, source)``.
            source: 'fg' plots pulse responses into ``self.fg_trace_plot``;
                'bg' plots baselines into ``self.bg_trace_plot``.
            pen: pen for QC-passing traces. If None, white with an alpha of
                ``1000 / len(recs)`` clipped to [30, 255], so traces fade as
                more records are overlaid.
            trace_list: list of plot items owned by this view. Items already
                in it are removed from the plot first, and every item drawn
                here is appended to it. If None, the per-source selection
                list (``self.selected_fg_traces`` / ``self.selected_bg_traces``)
                is used.
            avg: if True and any trace was drawn, also plot their mean in green.
            qc_filter: if True, skip records whose ``self.qc_field`` attribute
                is not True. If None, the state of ``self.ui.qc_check`` is used.

        Returns None; returns early (leaving the plot untouched) when no
        records are found.
        """
        if qc_filter is None:
            qc_filter = self.ui.qc_check.isChecked()

        with pg.BusyCursor():
            recs = self.get_pulse_recs(ids, source)
            if len(recs) == 0:
                return

            if source == 'fg':
                selected_items = self.selected_fg_traces
                plot = self.fg_trace_plot
            else:
                selected_items = self.selected_bg_traces
                plot = self.bg_trace_plot
            # Iterating trace_list=None used to raise TypeError; fall back to
            # this source's selection list so stale items can still be cleared.
            if trace_list is None:
                trace_list = selected_items

            # clear items drawn by a previous call
            for item in trace_list[:]:
                plot.removeItem(item)
                trace_list.remove(item)

            if pen is None:
                alpha = np.clip(1000 / len(recs), 30, 255)
                pen = (255, 255, 255, alpha)

            pen = pg.mkPen(pen)
            # qc-failed traces are tinted red by halving green and blue
            fail_color = pen.color()
            fail_color.setBlue(fail_color.blue() // 2)
            fail_color.setGreen(fail_color.green() // 2)
            qc_fail_pen = pg.mkPen(fail_color)

            # analysis options are identical for every record; read them once
            response_source = {'fg': 'pulse_response', 'bg': 'baseline'}[source]
            filter_opts = dict(
                deconvolve=self.ui.deconv_check.isChecked(),
                lpf=self.ui.lpf_check.isChecked(),
                remove_artifacts=self.ui.ar_check.isChecked(),
                bsub=self.ui.bsub_check.isChecked(),
            )
            align_to_spike = self.ui.align_check.isChecked()

            traces = []
            spike_times = []
            spike_values = []
            for rec in recs:
                # With QC filtering on, failed records are skipped; otherwise
                # they are drawn with the red-tinted pen.
                qc_pass = getattr(rec, self.qc_field) is True
                if qc_filter is True and not qc_pass:
                    continue

                result = analyze_response_strength(rec,
                                                   source=response_source,
                                                   **filter_opts)
                trace = result['dec_trace']

                # sample the trace at the spike before any time shift
                spike_values.append(trace.value_at([result['spike_time']])[0])
                if align_to_spike:
                    # set t0 to -spike_time so the spike is drawn at 0
                    # (assumes spike_time is measured from the trace start)
                    trace.t0 = -result['spike_time']
                    spike_times.append(0)
                else:
                    spike_times.append(result['spike_time'])

                traces.append(trace)
                trace_list.append(
                    plot.plot(trace.time_values,
                              trace.data,
                              pen=(pen if qc_pass else qc_fail_pen)))

            if avg and len(traces) > 0:
                mean = TSeriesList(traces).mean()
                trace_list.append(
                    plot.plot(mean.time_values, mean.data, pen='g'))
                trace_list[-1].setZValue(10)

            spike_scatter = pg.ScatterPlotItem(spike_times,
                                               spike_values,
                                               size=4,
                                               pen=None,
                                               brush=(200, 200, 0))
            spike_scatter.setZValue(-100)
            plot.addItem(spike_scatter)
            trace_list.append(spike_scatter)
Esempio n. 14
0
def first_pulse_plot(expt_list,
                     name=None,
                     summary_plot=None,
                     color=None,
                     scatter=0,
                     features=False):
    """Plot average first-pulse responses for connections of one cre-type pair.

    Connections are taken from each experiment in *expt_list* whose pre/post
    cells match the module-level ``cre_type`` pair. A connection is used when
    its artifact is at most 0.03e-3, at least 10 responses pass
    ``response_filter`` (freq 0-50, holding_range [-68, -72], pulse=True),
    and the sign of its average amplitude agrees with the presynaptic cre
    class (excitatory -> not negative, inhibitory -> not positive).

    Args:
        expt_list: experiments to scan.
        name: label used in printed summaries and in plot legends.
        summary_plot: plot passed through to ``summary_plot_pulse``.
        color: color passed through to ``summary_plot_pulse``.
        scatter: index passed as ``i`` to ``summary_plot_pulse``.
        features: if True, also fit each average with ``fit_psp`` and collect
            latency and rise time.

    Returns:
        ``(avg_amps, summary_plots)``; *avg_amps* has keys 'amp', 'latency'
        and 'rise', and *summary_plots* is None when no connection qualified.

    Also reads the module-level ``cre_type``, ``args.recip`` and ``app``.
    """
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    grand_response = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}
    for expt in expt_list:
        if expt.connections is None:
            continue
        for pre, post in expt.connections:
            pre_cre = expt.cells[pre].cre_type
            if pre_cre != cre_type[0] or expt.cells[post].cre_type != cre_type[1]:
                continue
            all_responses, artifact = get_response(
                expt, pre, post, analysis_type='pulse')
            if artifact > 0.03e-3:
                continue
            filtered_responses = response_filter(
                all_responses,
                freq_range=[0, 50],
                holding_range=[-68, -72],
                pulse=True)
            n_sweeps = len(filtered_responses)
            if n_sweeps < 10:
                continue
            avg_trace, avg_amp, amp_sign, _ = get_amplitude(filtered_responses)
            # drop connections whose amplitude sign contradicts the presynaptic class
            if pre_cre in EXCITATORY_CRE_TYPES and avg_amp < 0:
                continue
            if pre_cre in INHIBITORY_CRE_TYPES and avg_amp > 0:
                continue
            avg_trace.t0 = 0
            avg_amps['amp'].append(avg_amp)
            grand_response.append(avg_trace)
            if features is True:
                psp_fits = fit_psp(avg_trace,
                                   sign=amp_sign,
                                   yoffset=0,
                                   amp=avg_amp,
                                   method='leastsq',
                                   fit_kws={})
                # latency is reported relative to 10e-3 into the trace
                avg_amps['latency'].append(
                    psp_fits.best_values['xoffset'] - 10e-3)
                avg_amps['rise'].append(psp_fits.best_values['rise_time'])

            # When args.recip is set, connections whose reverse (post -> pre)
            # also exists in this experiment are drawn in red.
            reciprocal = (len(expt.connections) > 1 and args.recip is True
                          and (post, pre) in expt.connections)
            if reciprocal:
                amp_plots.plot(avg_trace.time_values,
                               avg_trace.data,
                               pen={'color': 'r', 'width': 1})
            else:
                amp_plots.plot(avg_trace.time_values, avg_trace.data)

            app.processEvents()

    if len(grand_response) == 0:
        print("No TSeries")
        return avg_amps, None

    print('%s n = %d' % (name, len(grand_response)))
    grand_mean = TSeriesList(grand_response).mean()
    grand_amp = np.mean(np.array(avg_amps['amp']))
    grand_amp_sem = stats.sem(np.array(avg_amps['amp']))
    amp_plots.addLegend()
    amp_plots.plot(grand_mean.time_values,
                   grand_mean.data,
                   pen={'color': 'g', 'width': 3},
                   name=name)
    amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
    # `name` is this function's label (formerly an undefined global `legend`)
    print('%s Grand mean amplitude = %f +- %f' % (name, grand_amp, grand_amp_sem))
    if features is True:
        feature_list = (avg_amps['amp'], avg_amps['latency'], avg_amps['rise'])
        labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
        titles = ('Amplitude', 'Latency', 'Rise time')
    else:
        feature_list = [avg_amps['amp']]
        labels = ['Vm', 'V']
        titles = 'Amplitude'
    # NOTE(review): only the amplitude feature (feature_list[0]) is passed on,
    # even when latency/rise were collected.
    summary_plots = summary_plot_pulse(feature_list[0],
                                       labels=labels,
                                       titles=titles,
                                       i=scatter,
                                       grand_trace=grand_mean,
                                       plot=summary_plot,
                                       color=color,
                                       name=name)
    return avg_amps, summary_plots
Esempio n. 15
0
def train_response_plot(expt_list,
                        name=None,
                        summary_plots=None,
                        color=None):
    """Plot 50 Hz train (induction + recovery) responses for one cre-type pair.

    Connections are taken from each experiment in *expt_list* whose pre/post
    cells match the module-level ``cre_type`` pair, whose artifact is at most
    0.03e-3, and which have more than 5 responses passing ``response_filter``
    (freq 50, train=0, delta_t=250).

    Args:
        expt_list: experiments to scan.
        name: label used in printed summaries and plot legends.
        summary_plots: two plots for ``summary_plot_train``; index 0 gets the
            induction and recovery grand means, index 1 the deconvolved
            induction grand mean. Defaults to ``[None, None]``.
        color: color passed through to ``summary_plot_train``.

    Returns:
        ``(train_plots, train_plots2, train_amps)``, or None when no
        connection qualified.

    Also reads the module-level ``cre_type`` and ``app``.
    """
    if summary_plots is None:
        summary_plots = [None, None]
    grand_train = [[], []]
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3
    lp = 1000
    pulse_offsets = None
    for expt in expt_list:
        # same guard as first_pulse_plot: experiments may have no connections
        if expt.connections is None:
            continue
        for pre, post in expt.connections:
            if (expt.cells[pre].cre_type != cre_type[0]
                    or expt.cells[post].cre_type != cre_type[1]):
                continue
            print('Processing experiment: %s' % (expt.nwb_file))

            train_responses, artifact = get_response(expt,
                                                     pre,
                                                     post,
                                                     analysis_type='train')
            if artifact > 0.03e-3:
                continue

            train_filter = response_filter(train_responses['responses'],
                                           freq_range=[50, 50],
                                           train=0,
                                           delta_t=250)
            pulse_offsets = response_filter(
                train_responses['pulse_offsets'],
                freq_range=[50, 50],
                train=0,
                delta_t=250)

            if len(train_filter[0]) <= 5:
                continue
            ind_avg = TSeriesList(train_filter[0]).mean()
            rec_avg = TSeriesList(train_filter[1]).mean()
            rec_avg.t0 = 0.3
            grand_train[0].append(ind_avg)
            grand_train[1].append(rec_avg)
            train_plots.plot(ind_avg.time_values, ind_avg.data)
            train_plots.plot(rec_avg.time_values, rec_avg.data)
            app.processEvents()

    if len(grand_train[0]) == 0:
        print("No TSeries")
        return None

    print('%s n = %d' % (name, len(grand_train[0])))
    ind_grand_mean = TSeriesList(grand_train[0]).mean()
    rec_grand_mean = TSeriesList(grand_train[1]).mean()
    # exponential deconvolution (time constant tau) followed by a Bessel low-pass at lp
    ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
    train_plots.addLegend()
    for grand_mean in (ind_grand_mean, rec_grand_mean):
        train_plots.plot(grand_mean.time_values,
                         grand_mean.data,
                         pen={'color': 'g', 'width': 3},
                         name=name)
    # NOTE(review): these are the pulse offsets of the last connection that
    # passed the artifact check, not of every connection -- confirm intended.
    train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets, '+')
    # `name` is this function's label (formerly an undefined global `legend`)
    induction_label = '%s 50 Hz induction' % name
    summary_plot_train(ind_grand_mean,
                       plot=summary_plots[0],
                       color=color,
                       name=induction_label)
    train_plots = summary_plot_train(rec_grand_mean,
                                     plot=summary_plots[0],
                                     color=color)
    train_plots2 = summary_plot_train(ind_grand_mean_dec,
                                      plot=summary_plots[1],
                                      color=color,
                                      name=induction_label)
    return train_plots, train_plots2, train_amps
Esempio n. 16
0
    def _get_tserieslist(self,
                         ts_name,
                         align,
                         bsub,
                         bsub_win=5e-3,
                         alignment_failure_mode='ignore'):
        """Collect one TSeries per pulse response, optionally baseline-subtracted
        and time-aligned.

        Args:
            ts_name: attribute read from each pulse response in ``self.prs``.
            align: None, 'spike', 'peak' or 'pulse'. 'spike' shifts time so
                ``stim_pulse.first_spike_time`` is 0, 'peak' uses the first
                spike's ``peak_time`` (only when exactly one spike was evoked),
                'pulse' uses the stimulus onset time.
            bsub: if True, subtract the ``float_mode`` of the data in the
                *bsub_win* window before stimulus onset (or the first sample
                when that window is empty).
            bsub_win: length of the baseline window.
            alignment_failure_mode: what to do with a response that has no
                alignment time: 'ignore' skips it, 'average' uses the mean
                alignment time of the responses that have one, 'raise' raises.

        Returns:
            TSeriesList of the processed series.

        Raises:
            ValueError: unknown *align*, or unknown *alignment_failure_mode*
                when a response lacks an alignment time.
            Exception: in 'average' mode when no response has an alignment
                time, or in 'raise' mode when a response lacks one.
        """
        if align is not None and align not in ('spike', 'peak', 'pulse'):
            raise ValueError(
                "align must be None, 'spike', 'peak', or 'pulse'.")

        def align_time(pr):
            # Alignment time for one pulse response, or None when unavailable.
            stim_pulse = pr.stim_pulse
            if align == 'spike':
                # first_spike_time is the max dv/dt of the spike
                return stim_pulse.first_spike_time
            if align == 'peak':
                # peak of the first spike, only when exactly one spike occurred
                if stim_pulse.n_spikes == 1:
                    return stim_pulse.spikes[0].peak_time
                return None
            # align == 'pulse': stimulus onset
            return stim_pulse.onset_time

        if align is not None and alignment_failure_mode == 'average':
            # Same rule as the per-response alignment below, so every valid
            # `align` (including 'pulse') is supported in 'average' mode.
            known_times = [t for t in (align_time(p) for p in self.prs)
                           if t is not None]
            # nan when nothing is known; it is only an error if some response
            # actually needs the fallback (checked below)
            average_align_t = np.mean(known_times) if known_times else np.nan

        tsl = []
        for pr in self.prs:
            ts = getattr(pr, ts_name)
            stim_time = pr.stim_pulse.onset_time

            if bsub is True:
                start_time = max(ts.t0, stim_time - bsub_win)
                baseline_data = ts.time_slice(start_time, stim_time).data
                if len(baseline_data) == 0:
                    baseline = ts.data[0]
                else:
                    baseline = float_mode(baseline_data)
                ts = ts - baseline

            if align is not None:
                align_t = align_time(pr)
                if align_t is None:
                    if alignment_failure_mode == 'ignore':
                        # ignore PRs with no known timing
                        continue
                    elif alignment_failure_mode == 'average':
                        align_t = average_align_t
                        if np.isnan(align_t):
                            raise Exception(
                                "average %s time is None, try another mode" %
                                align)
                    elif alignment_failure_mode == 'raise':
                        raise Exception(
                            "%s time is not available for pulse %s and can't be aligned"
                            % (align, pr))
                    else:
                        # previously fell through to `ts.t0 - None` (TypeError)
                        raise ValueError(
                            "alignment_failure_mode must be 'ignore', 'average', or 'raise'.")

                ts = ts.copy(t0=ts.t0 - align_t)

            tsl.append(ts)
        return TSeriesList(tsl)
Esempio n. 17
0
def plot_features(organism=None,
                  conn_type=None,
                  calcium=None,
                  age=None,
                  sweep_thresh=None,
                  fit_thresh=None):
    """Plot average first-pulse PSPs for every combination of filter values.

    Each filter argument is None (no filtering) or a comma-separated string of
    values; one row of ``response_grid`` is drawn per combination. Formats
    expected by the parsing below:
        conn_type: 'pre_layer;pre_cre-post_layer;post_cre' ('None' = no layer)
        age: 'lower-upper' (integers, passed to SQL BETWEEN)
        calcium: prefix matched with LIKE against ``Experiment.acsf``
    sweep_thresh, when given, keeps only rows with ``n_sweeps >= sweep_thresh``.

    Returns:
        (response_grid, feature_grid). ``fit_thresh`` is currently unused and
        nothing is plotted into ``feature_grid``.
    """
    session = db.session()

    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age
    }

    # Expand the filters into every combination of their comma-separated values.
    selection = [{}]
    for key, value in filters.items():
        if value is None:
            continue
        expanded = []
        for v in value.split(','):
            for base in selection:
                combo = base.copy()
                combo[key] = v
                expanded.append(combo)
        selection = expanded

    response_grid = None
    feature_grid = None
    if len(selection) > 0:

        response_grid = PlotGrid()
        response_grid.set_shape(len(selection), 1)
        response_grid.show()
        feature_grid = PlotGrid()
        feature_grid.set_shape(6, 1)
        feature_grid.show()

        for i, select in enumerate(selection):
            pre_cell = aliased(db.Cell)
            post_cell = aliased(db.Cell)
            # Used in titles/messages below; they stay None when no conn_type
            # is selected (previously a NameError in that case).
            pre_layer = pre_cre = post_layer = post_cre = None
            q_filter = []
            if sweep_thresh is not None:
                q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
            species = select.get('organism')
            if species is not None:
                q_filter.append(db.Slice.species == species)
            c_type = select.get('conn_type')
            if c_type is not None:
                pre_type, post_type = c_type.split('-')
                pre_layer, pre_cre = pre_type.split(';')
                if pre_layer == 'None':
                    pre_layer = None
                post_layer, post_cre = post_type.split(';')
                if post_layer == 'None':
                    post_layer = None
                q_filter.extend([
                    pre_cell.cre_type == pre_cre,
                    pre_cell.target_layer == pre_layer,
                    post_cell.cre_type == post_cre,
                    post_cell.target_layer == post_layer
                ])
            calc_conc = select.get('calcium')
            if calc_conc is not None:
                q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
            age_range = select.get('age')
            if age_range is not None:
                age_lower, age_upper = age_range.split('-')
                q_filter.append(
                    db.Slice.age.between(int(age_lower), int(age_upper)))

            q = session.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id==db.Pair.id)\
                .join(pre_cell, db.Pair.pre_cell_id==pre_cell.id)\
                .join(post_cell, db.Pair.post_cell_id==post_cell.id)\
                .join(db.Experiment, db.Experiment.id==db.Pair.expt_id)\
                .join(db.Slice, db.Slice.id==db.Experiment.slice_id)

            for filter_arg in q_filter:
                q = q.filter(filter_arg)

            results = q.all()

            trace_list = []
            for pair in results:
                #TODO set t0 to latency to align to foot of PSP
                trace = TSeries(data=pair.avg_psp,
                                sample_rate=db.default_sample_rate)
                trace_list.append(trace)
                response_grid[i, 0].plot(trace.time_values, trace.data)
            if len(trace_list) > 0:
                grand_trace = TSeriesList(trace_list).mean()
                response_grid[i, 0].plot(grand_trace.time_values,
                                         grand_trace.data,
                                         pen='b')
                response_grid[i, 0].setTitle(
                    'layer %s, %s-> layer %s, %s; n_synapses = %d' %
                    (pre_layer, pre_cre, post_layer, post_cre,
                     len(trace_list)))
            else:
                print('No synapses for layer %s, %s-> layer %s, %s' %
                      (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
Esempio n. 18
0
    def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
        """Plot average-response-fit traces for the connected pairs in *element*.

        For each connected pair with a synapse, the first average response fit
        per clamp mode that passes manual QC, has a holding value in
        ``syn_typ_holding[synapse_type]``, a known latency and a start time is
        aligned with ``format_trace``. Voltage-clamp fits (Bessel-filtered) go
        to ``trace_plt[0]`` and current-clamp fits to ``trace_plt[1]``. Grand
        averages are overlaid on each plot.

        Returns:
            (line, scatter): an InfiniteLine at
            ``element.agg(self.summary_stat)[field_name]['metric_summary']``,
            and None (no scatter is drawn here).

        Side effects:
            Stores ``[trace_itemA, trace_itemB]`` in ``self.pair_items[pair.id]``
            for every connected pair that has a synapse.
        """
        summary = element.agg(self.summary_stat)
        val = summary[field_name]['metric_summary']
        line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
        scatter = None
        tracesA = []
        tracesB = []
        # elementwise comparison on the column; `is True` would not work here
        connections = element[element['Connected'] == True].index.tolist()
        for pair in connections:
            synapse = pair.synapse
            if synapse is None:
                continue
            arfs = pair.avg_response_fits
            latency = synapse.latency
            syn_typ = synapse.synapse_type
            trace_itemA = None
            trace_itemB = None
            # NOTE: plotting of pair.resting_state_fit is disabled; only the
            # average response fits are used.
            if arfs is not None:
                for arf in arfs:
                    if arf.holding in syn_typ_holding[syn_typ] and arf.manual_qc_pass is True and latency is not None:
                        if arf.clamp_mode == 'vc' and trace_itemA is None:
                            traceA = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                            traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
                            start_time = arf.avg_data_start_time
                            if start_time is not None:
                                xoffset = start_time - latency
                                # baseline window of width 1e-3 ending at |xoffset|
                                baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                                traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
                                trace_itemA = trace_plt[0].plot(traceA.time_values, traceA.data)
                                trace_itemA.pair = pair
                                trace_itemA.curve.setClickable(True)
                                trace_itemA.sigClicked.connect(self.trace_plot_clicked)
                                tracesA.append(traceA)
                        if arf.clamp_mode == 'ic' and trace_itemB is None:
                            traceB = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                            start_time = arf.avg_data_start_time
                            if start_time is not None:
                                xoffset = start_time - latency
                                baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                                traceB = format_trace(traceB, baseline_window, x_offset=xoffset, align='psp')
                                trace_itemB = trace_plt[1].plot(traceB.time_values, traceB.data)
                                trace_itemB.pair = pair
                                trace_itemB.curve.setClickable(True)
                                trace_itemB.sigClicked.connect(self.trace_plot_clicked)
                                tracesB.append(traceB)
            self.pair_items[pair.id] = [trace_itemA, trace_itemB]

        if len(tracesA) > 0:
            grand_trace = TSeriesList(tracesA).mean()
            name = ('%s->%s' % (pre_class, post_class))
            trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
            trace_plt[0].setXRange(-5e-3, 20e-3)
            trace_plt[0].setLabels(left=('', 'A'), bottom=('Response Onset', 's'))
            trace_plt[0].setTitle('Voltage Clamp')
        if len(tracesB) > 0:
            grand_trace = TSeriesList(tracesB).mean()
            trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
            trace_plt[1].setLabels(right=('', 'V'), bottom=('Response Onset', 's'))
            trace_plt[1].setTitle('Current Clamp')
        return line, scatter