예제 #1
0
def format_trace(trace, baseline_win, x_offset=1e-3, align='spike'):
    """Return a baseline-subtracted copy of *trace* resampled onto a new TSeries.

    The baseline is the ``float_mode`` of the data inside
    ``trace.time_slice(baseline_win[0], baseline_win[1])``. The result uses
    ``db.default_sample_rate`` and the TSeries default t0, unless *align* is
    ``'psp'`` (alignment to PSP onset), in which case t0 is set to *x_offset*.
    The default ``'spike'`` alignment (pre-synaptic spike) leaves t0 untouched.
    """
    baseline_region = trace.time_slice(baseline_win[0], baseline_win[1])
    offset = float_mode(baseline_region.data)
    shifted = TSeries(data=trace.data - offset, sample_rate=db.default_sample_rate)
    if align != 'psp':
        return shifted
    shifted.t0 = x_offset
    return shifted
예제 #2
0
def create_test_pulse(start=5 * ms,
                      pdur=10 * ms,
                      pamp=-10 * pA,
                      mode='ic',
                      dt=10 * us,
                      r_access=10 * MOhm,
                      c_soma=5 * pF,
                      noise=5 * pA):
    """Simulate a square command pulse on ``model_cell`` and analyze the response.

    Side effects: sets ``model_cell.clamp.ra`` to *r_access* and the stdev of
    ``model_cell.mechs['noise']`` to *noise* before simulating.

    The command array is zero everywhere except ``[start, start + pdur)``,
    where it equals *pamp*; its total duration is ``start + 3 * pdur``.
    NOTE(review): *c_soma* is accepted but not used by this function.

    Returns a ``PatchClampTestPulse`` built from ``model_cell.test(...)``.
    """
    # update patch pipette access resistance
    model_cell.clamp.ra = r_access

    # update noise amplitude
    model_cell.mechs['noise'].stdev = noise

    # Convert times to sample counts by rounding rather than truncating: a
    # float quotient such as duration / dt can land just below the intended
    # integer, and int() alone would then drop a sample.
    duration = start + pdur * 3
    n_samples = int(round(duration / dt))
    pstart = int(round(start / dt))
    pstop = pstart + int(round(pdur / dt))

    pulse = np.zeros(n_samples)
    pulse[pstart:pstop] = pamp

    # simulate response
    result = model_cell.test(TSeries(pulse, dt), mode)

    # generate a PatchClampTestPulse to test against
    return PatchClampTestPulse(result)
예제 #3
0
 def baseline_tseries(self):
     """Wrap ``self.baseline`` in a TSeries, or return None if there is no baseline.

     The TSeries uses ``default_sample_rate`` and starts at the baseline's
     ``data_start_time``.
     """
     baseline = self.baseline
     if baseline is not None:
         return TSeries(baseline.data,
                        sample_rate=default_sample_rate,
                        t0=baseline.data_start_time)
     return None
def simulate_response(fg_recs, bg_results, amp, rtime, seed=None):
    """Inject synthetic PSPs into copies of *fg_recs* and run connectivity analysis.

    Each foreground record receives one PSP template (15 ms long, rise time
    *rtime*, decay tau 15 ms, rise power 2). Its amplitude is drawn as
    binomial(n=24, p=0.2) * normal(1, 0.3), and the amplitudes are rescaled so
    their mean equals *amp*. Onset latency is drawn from normal(13 ms, 200 us).
    The original records are not modified; each is wrapped in a
    ``RecordWrapper`` whose data is copied before injection.

    Parameters
    ----------
    fg_recs : sequence of records with a ``data`` array
    bg_results : background analysis results passed straight to
        ``analyze_pair_connectivity``
    amp : float
        Target mean PSP amplitude.
    rtime : float
        PSP rise time.
    seed : int or None
        If given, seeds NumPy's global RNG, which is also the source for the
        ``scipy.stats`` draws used here.

    Returns
    -------
    (conn_result, traces)
        The ``analyze_pair_connectivity`` result, and one TSeries per record
        carrying its injected amplitude as ``.amp``.
    """
    if seed is not None:
        np.random.seed(seed)

    sample_rate = db.default_sample_rate
    dt = 1.0 / sample_rate
    t = np.arange(0, 15e-3, dt)
    template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

    n_recs = len(fg_recs)
    r_amps = scipy.stats.binom.rvs(p=0.2, n=24, size=n_recs) * scipy.stats.norm.rvs(scale=0.3, loc=1, size=n_recs)
    r_amps *= amp / r_amps.mean()
    r_latency = np.random.normal(size=n_recs, scale=200e-6, loc=13e-3)

    fg_results = []
    traces = []
    # can't modify fg_recs, so we wrap records with a mutable shell
    wrapped = [RecordWrapper(rec) for rec in fg_recs]
    for k, rec in enumerate(wrapped):
        rec.data = rec.data.copy()
        start = int(r_latency[k] * sample_rate)
        # Add the template only where it overlaps the record. The record may be
        # longer than start + len(template), which would otherwise be a
        # broadcast error, or shorter than start.
        n_overlap = max(0, min(len(rec.data) - start, len(template)))
        rec.data[start:start + n_overlap] += template[:n_overlap] * r_amps[k]

        fg_results.append(analyze_response_strength(rec, 'baseline'))

        trace = TSeries(rec.data, sample_rate=sample_rate)
        trace.amp = r_amps[k]
        traces.append(trace)

    fg_results = str_analysis_result_table(fg_results, wrapped)
    conn_result = analyze_pair_connectivity({('ic', 'fg'): fg_results, ('ic', 'bg'): bg_results, ('vc', 'fg'): [], ('vc', 'bg'): []}, sign=1)
    return conn_result, traces
    def plot_all(self):
        """Draw every pulse response in ``self.sorted_recs`` with its recorded PSP fit.

        One plot row is created per stimulus key, in sorted key order. Each key
        is formatted into the row title with ``"%s  %0.0f Hz  %0.2f s"``, so it is
        presumably a (name, frequency, delay) tuple -- confirm.

        Each response is aligned to ``stim_pulse.first_spike_time``, or to
        ``onset_time + 1 ms`` when that is None. It is shifted right by 35 ms per
        pulse number, with an extra 30 ms gap after pulse 8. QC-failed responses
        are drawn dark red behind the rest. QC-passed responses that have a
        fit record with a non-None ``fit_amp`` get the fit overlaid in green.
        """
        self.plots.clear()
        self.plots.set_shape(len(self.sorted_recs), 1)
        psp = StackedPsp()

        for row, stim_key in enumerate(sorted(self.sorted_recs)):
            prs = self.sorted_recs[stim_key]
            plot = self.plots[row, 0]
            plot.setTitle("%s  %0.0f Hz  %0.2f s" % stim_key)

            for recording in prs:
                for pulse_n in sorted(prs[recording]):
                    rec = prs[recording][pulse_n]
                    # spike-align pulse + offset for pulse number
                    spike_t = rec.stim_pulse.first_spike_time
                    if spike_t is None:
                        spike_t = rec.stim_pulse.onset_time + 1e-3

                    # which QC flag applies depends on the synapse type
                    if rec.synapse.synapse_type == 'in':
                        qc_pass = rec.pulse_response.in_qc_pass
                    else:
                        qc_pass = rec.pulse_response.ex_qc_pass
                    pen = (255, 255, 255, 100) if qc_pass else (100, 0, 0, 100)

                    ts = TSeries(data=rec.data,
                                 t0=rec.pulse_response.data_start_time - spike_t,
                                 sample_rate=db.default_sample_rate)
                    c = plot.plot(ts.time_values, ts.data, pen=pen)

                    # arrange plots nicely
                    shift = (pulse_n * 35e-3 + (30e-3 if pulse_n > 8 else 0), 0)
                    c.setPos(*shift)

                    if not qc_pass:
                        c.setZValue(-10)
                        continue

                    # Evaluate the recorded fit for this response. A response can
                    # have no fit record at all; guard before reading fit_amp.
                    fit_par = rec.pulse_response_fit
                    if fit_par is None or fit_par.fit_amp is None:
                        continue
                    # NOTE(review): exp_tau reuses fit_decay_tau; confirm intended.
                    fit = psp.eval(
                        x=ts.time_values,
                        exp_amp=fit_par.fit_exp_amp,
                        exp_tau=fit_par.fit_decay_tau,
                        amp=fit_par.fit_amp,
                        rise_time=fit_par.fit_rise_time,
                        decay_tau=fit_par.fit_decay_tau,
                        xoffset=fit_par.fit_latency,
                        yoffset=fit_par.fit_yoffset,
                        rise_power=2,
                    )
                    c = plot.plot(ts.time_values, fit, pen=(0, 255, 0, 100))
                    c.setZValue(10)
                    c.setPos(*shift)
def format_responses(responses):
    """Group trial recordings into TSeries lists keyed by stimulus parameter.

    *responses* is indexed by ``'data'``, ``'dt'`` and ``'stim_param'``, each a
    per-trial sequence of equal length. Every trial becomes
    ``TSeries(data, dt=dt, stim_param=[param])`` and is appended to the list
    stored under its ``param`` in the returned dict. No trials -> ``{}``.
    """
    grouped = {}
    for idx in range(len(responses['data'])):
        key = responses['stim_param'][idx]
        bucket = grouped.setdefault(key, [])
        bucket.append(TSeries(data=responses['data'][idx],
                              dt=responses['dt'][idx],
                              stim_param=[key]))
    return grouped
def test_threshold_events():
    """Check threshold_events on flat, single-event, two-event and edge-touching traces.

    Events that touch the start or end of the trace are expected to be dropped
    by default and reported only when ``omit_ends=False``.
    """
    empty_result = np.array([], dtype=dtype)

    d = TSeries(np.zeros(10), dt=0.1)

    # flat trace: no events at any threshold
    check_events(threshold_events(d, 1), empty_result)
    check_events(threshold_events(d, 0), empty_result)

    # one positive event
    d.data[5:7] = 6
    expected = np.array([(5, 2, 12., 6., 5, 0.5, 0.2, 0.6, 0.5)], dtype=dtype)
    check_events(threshold_events(d, 1), expected)

    # add a negative event before it
    d.data[2:4] = -6
    expected = np.array([
        (2, 2, -12., -6., 2, 0.2, 0.2, -0.6, 0.2),
        (5, 2,  12.,  6., 5, 0.5, 0.2,  0.6, 0.5)],
        dtype=dtype
    )
    check_events(threshold_events(d, 1), expected)

    # data ends above threshold
    d.data[:] = 0
    d.data[5:] = 6
    check_events(threshold_events(d, 1), empty_result)
    expected = np.array([(5, 5, 30., 6., 5, 0.5, 0.5, 2.4, 0.5)], dtype=dtype)
    check_events(threshold_events(d, 1, omit_ends=False), expected)

    # data begins above threshold
    d.data[:] = 6
    d.data[5:] = 0
    check_events(threshold_events(d, 1), empty_result)
    expected = np.array([(0, 5, 30., 6., 0, 0., 0.5, 2.4, 0.)], dtype=dtype)
    check_events(threshold_events(d, 1, omit_ends=False), expected)

    # all points above threshold
    d.data[:] = 6
    check_events(threshold_events(d, 1), empty_result)
    expected = np.array([(0, 10, 60., 6., 0, 0., 1., 5.4, 0.)], dtype=dtype)
    check_events(threshold_events(d, 1, omit_ends=False), expected)
def test_exp_deconv_psp_params():
    """exp_deconv_psp_params should predict the shape exp_deconvolve produces.

    A PSP is deconvolved with its own decay tau. The result is compared
    (atol=0.02) against a PSP evaluated directly from the parameters returned
    by exp_deconv_psp_params. The deconvolved trace is one sample shorter, so
    the directly evaluated curve drops its first sample.
    """
    from neuroanalysis.event_detection import exp_deconvolve, exp_deconv_psp_params
    from neuroanalysis.data import TSeries
    from neuroanalysis.fitting import Psp

    x = np.linspace(0, 0.02, 10000)
    shape = {'amp': 1, 'rise_time': 4e-3, 'decay_tau': 10e-3, 'rise_power': 2}

    # reference PSP waveform and its exponential deconvolution
    psp = Psp()
    y = psp.eval(x=x, xoffset=0, yoffset=0, **shape)
    y_deconv = exp_deconvolve(TSeries(y, time_values=x), shape['decay_tau']).data

    # PSP predicted directly from the transformed parameters
    d_amp, d_rise_time, d_rise_power, d_decay_tau = exp_deconv_psp_params(
        shape['amp'], shape['rise_time'], shape['rise_power'], shape['decay_tau'])
    predicted = psp.eval(x=x, xoffset=0, yoffset=0, amp=d_amp,
                         rise_time=d_rise_time, decay_tau=d_decay_tau,
                         rise_power=d_rise_power)

    assert np.allclose(y_deconv, predicted[1:], atol=0.02)
def load_next():
    """Advance to the next pulse chunk, show spike detection, and stage a test case.

    Pulls ``(expt_id, cell_id, sweep, channel, chunk)`` from the global
    ``all_pulses`` iterator. When it is exhausted, ``ui.widget`` is hidden and
    nothing else happens. Otherwise ``detect_evoked_spikes`` runs on the chunk,
    its result is shown in ``ui``, and a ``SpikeDetectTestCase`` holding a
    trimmed copy of the chunk is stored in the global ``last_result``.
    """
    global all_pulses, ui, last_result
    try:
        expt_id, cell_id, sweep, channel, chunk = next(all_pulses)
    except StopIteration:
        ui.widget.hide()
        return

    # run spike detection on this chunk and display it
    ui.show_result(detect_evoked_spikes(chunk, chunk.meta['pulse_edges'], ui=ui))

    # copy only the channel data/timing needed for export, plus the metadata
    exported_channels = {}
    for name in chunk.channels:
        source = chunk[name]
        exported_channels[name] = TSeries(source.data,
                                          t0=source.t0,
                                          sample_rate=source.sample_rate)
    export_chunk = PatchClampRecording(channels=exported_channels)
    export_chunk.meta.update(chunk.meta)

    # assemble the test case from identifiers and inputs
    case = SpikeDetectTestCase()
    case._meta = dict(expt_id=expt_id, cell_id=cell_id,
                      device_id=channel, sweep_id=sweep.key)
    case._input_args = dict(data=export_chunk,
                            pulse_edges=chunk.meta['pulse_edges'])
    last_result = case
예제 #10
0
파일: data.py 프로젝트: shixnya/aisynphys
 def recorded_tseries(self):
     """Return ``self.data`` as a TSeries, building it once and caching it.

     The cached TSeries uses ``default_sample_rate`` and starts at
     ``self.data_start_time``; later calls return the same object.
     """
     cached = self._rec_tseries
     if cached is None:
         cached = TSeries(self.data,
                          sample_rate=default_sample_rate,
                          t0=self.data_start_time)
         self._rec_tseries = cached
     return cached
예제 #11
0
    # NOTE(review): fragment of a function whose header is not shown here;
    # `diff`, `trace`, `dt`, `sigma`, `plot` are defined earlier. Presumably
    # `diff` is a sample-to-sample difference of the stimulus command, so its
    # positive/negative entries mark pulse on/off edges -- confirm.
    on_times = np.argwhere(diff > 0)[:, 0]   # sample indices of rising edges
    off_times = np.argwhere(diff < 0)[:, 0]  # sample indices of falling edges

    # decide on the region of the trace to focus on: from 1000 samples before
    # rising edge #1 to 1000 samples after falling edge #8 (edge #0 is skipped)
    start = on_times[1] - 1000
    stop = off_times[8] + 1000
    chunk = trace[start:stop]

    # plot the selected chunk, plus the derivative of a Gaussian-smoothed copy
    # (np.diff is one sample shorter, hence t[:-1])
    t = np.arange(chunk.shape[0]) * dt
    plot.plot(t[:-1], np.diff(ndi.gaussian_filter(chunk, sigma)), pen=0.5)
    plot.plot(t, chunk)

    # detect spike times
    peak_inds = []
    rise_inds = []
    for j in range(8):  # loop over pulses (edges 1..8)
        # pulse edges expressed relative to the start of `chunk`
        pstart = on_times[j + 1] - start
        pstop = off_times[j + 1] - start
        spike_info = detect_vc_evoked_spike(TSeries(chunk, dt=dt),
                                            pulse_edges=(pstart, pstop))
        # pulses where no spike was reported (None) are simply skipped
        if spike_info is not None:
            peak_inds.append(spike_info['peak_index'])
            rise_inds.append(spike_info['rise_index'])

    # display spike rise and peak times as ticks (sample index * dt -> time)
    pticks = pg.VTickGroup(np.array(peak_inds) * dt, yrange=[0, 0.3], pen='r')
    rticks = pg.VTickGroup(np.array(rise_inds) * dt, yrange=[0, 0.3], pen='y')
    plot.addItem(pticks)
    plot.addItem(rticks)
예제 #12
0
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below). The dict and its arrays are not modified.
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory.

    Input must have the following structure::

        amps = {
            'ic': recs,
            'vc': recs,
        }

    Where each *recs* must be a structured array containing fields as returned
    by get_amps(). A clamp mode that is missing or has no records is skipped.

    Returns
    -------
    dict or None
        Fields for the new DB record, or None if no record passes QC.

    The overall strategy here is:

    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency
    """
    # Filter by QC into a new dict so the caller's records are left untouched.
    filtered = {}
    for key, recs in amps.items():
        if len(recs) == 0:
            # nothing to filter (also handles empty lists, which lack fields)
            filtered[key] = recs
            continue
        filtered[key] = recs[recs['qc_pass'].astype(bool)]
    amps = filtered

    # See if any data remains (check the record arrays, not the dict keys)
    if all(len(recs) == 0 for recs in amps.values()):
        return None

    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # Which QC flag applies to a positive/negative deflection in each clamp mode
    qc_fields = {
        'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
        'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
    }

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_means = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_amps = amps.get(clamp_mode)
        if clamp_mode_amps is None or len(clamp_mode_amps) == 0:
            continue
        for polarity in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_mask = clamp_mode_amps[qc_fields[clamp_mode][polarity]].astype(bool)
            fg = clamp_mode_amps[qc_mask]
            qc_amps[polarity, clamp_mode] = fg
            if len(fg) == 0:
                continue

            # Measure some statistics from these records
            bg_amp = fg['baseline_' + polarity + '_dec_amp']
            fg_amp = fg[polarity + '_dec_amp']
            ks_pvals[polarity, clamp_mode] = scipy.stats.ks_2samp(fg_amp, bg_amp).pvalue
            # we could ensure that the average amplitude is in the right direction:
            fg_mean = np.mean(fg_amp)
            bg_mean = np.mean(bg_amp)
            amp_means[polarity, clamp_mode] = {'fg': fg_mean, 'bg': bg_mean}
            amp_diffs[polarity, clamp_mode] = fg_mean - bg_mean

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        #   strategy: accumulate evidence for either possibility by checking
        #   the ks p-values for each sign/clamp mode and the direction of the deflection
        is_exc = 0
        for polarity in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((polarity, mode))
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                ks = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[polarity, mode] > 0 else -1
                if mode == 'vc':
                    # in voltage clamp an excitatory response deflects the other way
                    dif_sign *= -1
                is_exc += dif_sign * ks
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        polarity = signs[clamp_mode]
        fg = qc_amps.get((polarity, clamp_mode))
        if fg is None or len(fg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(fg['baseline_crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'),
                           ('latency', 'dec_latency')]:
            f = fg[polarity + '_' + field]
            b = fg['baseline_' + polarity + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg
            # Note: we use log(1-log(pval)) because it's nicer to plot and easier to
            # use as a classifier input
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit

        # collect all fg traces, time-aligned to the presynaptic spike
        fg_traces = TSeriesList()
        for rec in fg:
            slope_time = rec['max_slope_time']
            # test for None first: np.isfinite(None) raises TypeError
            if slope_time is None or not np.isfinite(slope_time):
                continue
            t0 = rec['response_start_time'] - slope_time
            fg_traces.append(TSeries(rec['data'],
                                     sample_rate=db.default_sample_rate,
                                     t0=t0))

        if len(fg_traces) == 0:
            continue

        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        sign_factor = {'pos': 1, 'neg': -1}[polarity]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub,
                          clamp_mode=clamp_mode,
                          sign=sign_factor,
                          search_window=[0, 6e-3])
            for param, value in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = value
            # the fit was done on baseline-subtracted data; add the base back
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except Exception:
            # a failed fit is reported but does not abort the remaining clamp mode
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())

    return fields
예제 #13
0
def test_trace_timing():
    """Check how TSeries derives dt / sample_rate / time_values from its timing
    arguments, and that conflicting or malformed timing arguments are rejected.
    """
    # Make sure sample timing is handled exactly--need to avoid fp error here
    a = np.random.normal(size=300)
    sr = 50000
    dt = 2e-5
    t = np.arange(len(a)) * dt

    # trace with no timing information: every timing accessor raises TypeError
    tr = TSeries(a)
    assert not tr.has_timing
    assert not tr.has_time_values
    with raises(TypeError):
        tr.dt
    with raises(TypeError):
        tr.sample_rate
    with raises(TypeError):
        tr.time_values
    with raises(TypeError):
        tr.time_at(0)
    with raises(TypeError):
        tr.index_at(0.1)
    with raises(TypeError):
        tr.value_at(0.1)

    # slicing an untimed trace must not invent timing information
    view = tr[100:200]
    assert not view.has_timing
    assert not view.has_time_values

    # invalid data
    with raises(ValueError):
        TSeries(data=np.zeros((10, 10)))

    # invalid timing information
    with raises(TypeError):
        TSeries(data=a, dt=dt, time_values=t)
    with raises(TypeError):
        TSeries(data=a, sample_rate=sr, time_values=t)
    with raises(TypeError):
        TSeries(data=a, dt=dt, t0=0, time_values=t)
    with raises(TypeError):
        TSeries(data=a, dt=dt, t0=0, sample_rate=sr)
    with raises(ValueError):
        TSeries(data=a, time_values=t[:-1])

    # trace with only dt
    tr = TSeries(a, dt=dt)
    assert tr.dt == dt
    assert np.allclose(tr.sample_rate, sr)
    check_trace(tr, data=a, time_values=t, has_timing=True, has_time_values=False, regularly_sampled=True)

    # trace with only sample_rate
    tr = TSeries(a, sample_rate=sr)
    assert tr.dt == dt
    assert tr.sample_rate == sr
    assert np.all(tr.time_values == t)
    check_trace(tr, data=a, time_values=t, has_timing=True, has_time_values=False, regularly_sampled=True)

    # trace with only regularly-sampled time_values
    tr = TSeries(a, time_values=t)
    assert tr.dt == dt
    assert np.allclose(tr.sample_rate, sr)
    assert np.all(tr.time_values == t)
    check_trace(tr, data=a, time_values=t, has_timing=True, has_time_values=True, regularly_sampled=True)

    # trace with irregularly-sampled time values
    t1 = np.cumsum(np.random.normal(loc=1, scale=0.02, size=a.shape))
    tr = TSeries(a, time_values=t1)
    assert tr.dt == t1[1] - t1[0]
    assert np.all(tr.time_values == t1)
    check_trace(tr, data=a, time_values=t1, has_timing=True, has_time_values=True, regularly_sampled=False)
예제 #14
0
def extract_first_pulse_info_from_Pair_object(pair, desired_clamp='ic'):
    """Extract first pulse responses and relevant information
    from entry in the pair database. Screen out pulses that are
    not in *desired_clamp* mode, were not evoked by exactly one presynaptic
    spike, are not the first pulse of the train, lack a spike time, or do
    not pass the qc matching the pair's synapse type.

    Input
    -----
    pair: aisynphys.database.database.Pair object
    desired_clamp: string
        Specifies whether current or voltage clamp sweeps are desired.
        Options are:
            'ic': current clamp
            'vc': voltage clamp

    Return
    ------
    pulse_responses: list of TSeries
        spike-aligned traces starting ``time_before_spike`` before the spike
    pulse_ids: list of ints
        pulse ids of *pulse_responses*
    psp_amps_measured: list of floats
        amplitude of *pulse_responses* from the *pulse_response_strength*
        record (``pos_amp`` for 'ex', ``neg_amp`` for 'in')
    stim_freqs: list of floats
        the stimulation frequency corresponding to each of *pulse_responses*

    All four lists are empty when the pair has no synapse prediction, no
    synapse type, or no pulse responses.
    """
    prediction = pair.synapse_prediction
    if prediction is None or prediction.synapse_type is None:
        return [], [], [], []
    synapse_type = prediction.synapse_type
    if len(pair.pulse_responses) == 0:
        return [], [], [], []

    pulse_responses = []
    pulse_ids = []
    psp_amps_measured = []
    stim_freqs = []
    for pr in pair.pulse_responses:
        stim_pulse = pr.stim_pulse
        pcr = stim_pulse.recording.patch_clamp_recording
        # only the requested clamp mode
        if pcr.clamp_mode != desired_clamp:
            continue
        # ensure that there was only 1 presynaptic spike
        if stim_pulse.n_spikes != 1:
            continue
        # we only want the first pulse of the train
        if stim_pulse.pulse_number != 1:
            continue
        spike_time = stim_pulse.spikes[0].max_slope_time
        if spike_time is None:
            continue

        # keep only responses passing the qc that matches the synapse type
        if synapse_type == 'ex' and pr.ex_qc_pass is True:
            amp = pr.pulse_response_strength.pos_amp
        elif synapse_type == 'in' and pr.in_qc_pass is True:
            amp = pr.pulse_response_strength.neg_amp
        else:
            continue

        # shift t0 so the spike sits at time_before_spike, then keep t >= 0
        data_trace = TSeries(
            data=pr.data,
            t0=pr.data_start_time - spike_time + time_before_spike,
            sample_rate=db.default_sample_rate).time_slice(start=0, stop=None)

        pulse_responses.append(data_trace)
        pulse_ids.append(pr.stim_pulse_id)
        psp_amps_measured.append(amp)
        stim_freqs.append(pcr.multi_patch_probe.induction_frequency)

    return pulse_responses, pulse_ids, psp_amps_measured, stim_freqs
예제 #15
0
def _empty_first_pulse_fit():
    """Placeholder fit results for the clamp mode that was not fitted."""
    return {
        'amp': None,
        'latency': None,
        'rise_time': None,
        'decay_tau': None,
        'psp_data': np.array([0]),
        'psp_fit': np.array([0]),
        'dt': None,
        'nrmse': None,
    }


def _fit_first_pulse(pr, excitation, clamp_mode, latency):
    """Spike-align *pr*, weight it, fit a PSP with fit_trace, and return the results."""
    # Shift t0 so the spike sits at time_before_spike, then keep t >= 0.
    data_trace = TSeries(
        data=pr.data,
        t0=pr.response_start_time - pr.spike_time + time_before_spike,
        sample_rate=db.default_sample_rate).time_slice(start=0, stop=None)
    dt = data_trace.dt

    weight = np.ones(len(data_trace.data)) * 10.  # set everything to ten initially
    if clamp_mode == 'ic':
        # Ignore the stim artifact (3 ms before the spike). Since this is
        # spike-aligned there will be some blur in where the cross talk is.
        weight[int((time_before_spike - 3e-3) / dt):int(time_before_spike / dt)] = 0.
    # Emphasize the 4 ms starting 0.1 ms after the expected latency (steep PSP rise).
    rise_start = time_before_spike + .0001 + latency
    weight[int(rise_start / dt):int((rise_start + 4e-3) / dt)] = 30.

    # only the vc fit passes clamp_mode explicitly; ic relies on fit_trace's default
    fit_kwargs = {'clamp_mode': 'vc'} if clamp_mode == 'vc' else {}
    fit = fit_trace(data_trace,
                    excitation=excitation,
                    weight=weight,
                    latency=latency,
                    latency_jitter=.5e-3,
                    **fit_kwargs)
    return {
        'amp': fit.best_values['amp'],
        'latency': fit.best_values['xoffset'] - time_before_spike,
        'rise_time': fit.best_values['rise_time'],
        'decay_tau': fit.best_values['decay_tau'],
        'psp_data': data_trace.data,
        'psp_fit': fit.best_fit,
        'dt': dt,
        'nrmse': fit.nrmse(),
    }


def fit_single_first_pulse(pr, pair):
    """Fit a PSP to a single first-pulse response in its own clamp mode.

    The latency from ``pair.avg_first_pulse_fit`` (``ic_latency`` or
    ``vc_latency``) seeds the fit.

    Returns
    -------
    dict
        ``{'error': <message>}`` when the pulse fails the qc matching the
        synapse type, when the pair has no avg_first_pulse_fit entry, or when
        the needed latency is missing. Otherwise a dict with ``ic_*`` and
        ``vc_*`` amp, latency, rise_time, decay_tau, psp_data, psp_fit, dt and
        nrmse entries. Those for the unfitted clamp mode are None, or
        ``np.array([0])`` for the waveform entries. ``'error'`` is None.

    Raises
    ------
    Exception
        If the pair has no synapse_type or the pulse has an unknown clamp mode.
    """
    # TODO: HAS THE APPROPRIATE QC HAPPENED?
    message = None  # initialize error message for downstream processing
    # excitatory or inhibitory?
    excitation = pair.synapse_prediction.synapse_type
    if not excitation:
        raise Exception('there is no synapse_type in synapse_prediction')

    if excitation == 'in' and not pr.in_qc_pass:
        return {'error': 'this pulse does not pass inhibitory qc'}
    if excitation == 'ex' and not pr.ex_qc_pass:
        return {'error': 'this pulse does not pass excitatory qc'}

    # get response latency from average first pulse table
    avg_fit = pair.avg_first_pulse_fit
    if not avg_fit:
        return {'error': 'no entry in avg_first_pulse_fit table for this pair'}

    clamp_mode = pr.clamp_mode
    if clamp_mode == 'vc':
        latency = avg_fit.vc_latency
        missing_latency_error = 'no vc_latency available from avg_first_pulse_fit table'
    elif clamp_mode == 'ic':
        latency = avg_fit.ic_latency
        missing_latency_error = 'no ic_latency available from avg_first_pulse_fit table'
    else:
        raise Exception('There is no clamp mode associated with this pulse')
    if not latency:
        # no row will be made in the table because the error message is not none
        return {'error': missing_latency_error}

    results = {'ic': _empty_first_pulse_fit(), 'vc': _empty_first_pulse_fit()}
    results[clamp_mode] = _fit_first_pulse(pr, excitation, clamp_mode, latency)

    # dictionary for ease of translation into the output table
    out_dict = {}
    for mode in ('ic', 'vc'):
        for key in ('amp', 'latency', 'rise_time', 'decay_tau',
                    'psp_data', 'psp_fit', 'dt', 'nrmse'):
            out_dict[mode + '_' + key] = results[mode][key]
    out_dict['error'] = message
    return out_dict
예제 #16
0
def filter_pulse_responses(pair):
    """Collect first-pulse, single-spike current-clamp responses for *pair*.

    A pulse response is kept only when all of the following hold:

    * it was recorded in current clamp (``clamp_mode == 'ic'``);
    * its stimulus evoked exactly one presynaptic spike;
    * it is the first pulse of its train;
    * the train's induction frequency is not above 50 Hz;
    * it passed the QC flag matching the pair's predicted synapse type
      (``ex_qc_pass`` for 'ex', ``in_qc_pass`` for 'in').

    Returns
    -------
    (pulse_responses, pulse_ids, pulse_response_amps)
        Three parallel lists:

        * the kept responses as TSeries, with t0 shifted so that time is
          measured from the presynaptic spike's max dV/dt time;
        * their ``stim_pulse_id`` values;
        * their response strength (``pos_amp`` for excitatory pairs,
          ``neg_amp`` for inhibitory pairs).
    """
    # TODO: this filtering could probably be done in a single database query
    # joining PulseResponse, StimPulse, StimSpike and PatchClampRecording
    # instead of looping over the pair's responses in Python.
    syn_type = pair.synapse_prediction.synapse_type
    traces, ids, amps = [], [], []

    for response in pair.pulse_responses:
        pulse = response.stim_pulse
        spike_count = pulse.n_spikes
        pulse_index = pulse.pulse_number
        response_id = response.stim_pulse_id
        passed_ex = response.ex_qc_pass
        passed_in = response.in_qc_pass
        recording = pulse.recording.patch_clamp_recording
        frequency = recording.multi_patch_probe[0].induction_frequency
        mode = recording.clamp_mode

        # current clamp only, a single presynaptic spike, first pulse of the
        # train, and induction frequencies up to 50 Hz
        if mode != 'ic' or spike_count != 1 or pulse_index != 1 or frequency > 50:
            continue

        samples = response.data
        t_start = response.start_time
        t_spike = pulse.spikes[0].max_dvdt_time
        aligned = TSeries(data=samples,
                          t0=t_start - t_spike,
                          sample_rate=db.default_sample_rate)

        if syn_type == 'ex':
            if passed_ex is not True:
                continue
            amp = response.pulse_response_strength.pos_amp
        elif syn_type == 'in':
            if passed_in is not True:
                continue
            amp = response.pulse_response_strength.neg_amp
        else:
            continue

        traces.append(aligned)
        ids.append(response_id)
        amps.append(amp)

    return traces, ids, amps
예제 #17
0
 def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
     """Plot the per-pair values of one feature plus the pairs' average responses.

     For every pair in ``element[field_name]`` that has a synapse and a
     non-NaN value, the pair's resting-state average response is:

     * baseline-subtracted and aligned with ``format_trace``;
     * drawn on ``trace_plt[1]``;
     * for the 'Latency' field, also drawn as the voltage-clamp average on
       ``trace_plt[0]``.

     Grand-average traces are drawn on top of the individual traces. Clicked
     traces and scatter points are forwarded to ``self.trace_plot_clicked`` /
     ``self.scatter_plot_clicked``. ``self.pair_items[pair.id]`` records the
     plot items belonging to each pair.

     Parameters
     ----------
     pre_class, post_class :
         Used only in the grand-average legend name.
     element :
         ``element[field_name]`` is averaged with ``.mean()`` and iterated as
         (pair, value) items. Presumably a pandas Series indexed by pair
         objects; confirm against callers.
     field_name : str
         Names starting with 'PSC' select voltage-clamp data; all others
         select current-clamp data.
     color :
         Line/trace color. A tuple/list gets an alpha of 150 appended for the
         scatter brush. Any other pyqtgraph color spec (such as the default
         'g') is parsed with ``pg.mkColor`` and given alpha 150.
     trace_plt :
         Indexable pair of plot items. Index 0 receives the voltage-clamp
         traces for 'Latency'; index 1 receives the traces for field_name's
         clamp mode.

     Returns
     -------
     (line, scatter)
         The InfiniteLine at the mean value and the ScatterPlotItem of
         per-pair values.
     """
     val = element[field_name].mean()
     line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
     is_psc = field_name.startswith('PSC')
     # Axis units and title depend only on field_name. They are computed up
     # front so that the voltage-clamp (tracesB) branch below can use `units`
     # even when no current-clamp trace was collected.
     units = 'A' if is_psc else 'V'
     title = 'Voltage Clamp' if is_psc else 'Current Clamp'
     values = []
     tracesA = []
     tracesB = []
     point_data = []
     for pair, value in element[field_name].items():
         latency = self.results.loc[pair]['Latency']
         trace_itemA = None
         trace_itemB = None
         if pair.has_synapse is not True:
             continue
         if np.isnan(value):
             continue
         rsf = pair.resting_state_fit
         if rsf is not None:
             # NOTE: a fit-quality cut (e.g. skipping nrmse > 0.8) was
             # considered here but is currently disabled.
             data = rsf.vc_avg_data if is_psc else rsf.ic_avg_data
             traceA = TSeries(data=data, sample_rate=db.default_sample_rate)
             if is_psc:
                 # only the voltage-clamp average is low-pass filtered (5 kHz)
                 traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
             start_time = rsf.vc_avg_data_start_time if is_psc else rsf.ic_avg_data_start_time
             if latency is not None and start_time is not None:
                 if field_name == 'Latency':
                     xoffset = start_time + latency
                 else:
                     xoffset = start_time - latency
                 # baseline is taken from the 1 ms preceding |xoffset|
                 baseline_window = [abs(xoffset)-1e-3, abs(xoffset)]
                 traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
                 trace_itemA = trace_plt[1].plot(traceA.time_values, traceA.data)
                 trace_itemA.pair = pair
                 trace_itemA.curve.setClickable(True)
                 trace_itemA.sigClicked.connect(self.trace_plot_clicked)
                 tracesA.append(traceA)
             # NOTE: a `rsf.vc_nrmse < 0.8` cut is likewise disabled here.
             if field_name == 'Latency' and rsf.vc_nrmse is not None:
                 traceB = TSeries(data=rsf.vc_avg_data, sample_rate=db.default_sample_rate)
                 traceB = bessel_filter(traceB, 5000, btype='low', bidir=True)
                 start_time = rsf.vc_avg_data_start_time
                 if latency is not None and start_time is not None:
                     xoffset = start_time + latency
                     baseline_window = [abs(xoffset)-1e-3, abs(xoffset)]
                     traceB = format_trace(traceB, baseline_window, x_offset=xoffset, align='psp')
                     trace_itemB = trace_plt[0].plot(traceB.time_values, traceB.data)
                     trace_itemB.pair = pair
                     trace_itemB.curve.setClickable(True)
                     trace_itemB.sigClicked.connect(self.trace_plot_clicked)
                     tracesB.append(traceB)
         self.pair_items[pair.id] = [trace_itemA, trace_itemB]
         # only pairs whose trace was actually plotted get a scatter point
         if trace_itemA is not None:
             values.append(value)
             point_data.append(pair)
     y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
     if isinstance(color, (tuple, list)):
         brush = tuple(color) + (150,)
     else:
         # named color specs (e.g. the default 'g') cannot be concatenated
         # with an alpha tuple; let pyqtgraph parse them instead
         brush = pg.mkColor(color)
         brush.setAlpha(150)
     scatter = pg.ScatterPlotItem(symbol='o', brush=brush, pen='w', size=12)
     scatter.setData(values, y_values + 10., data=point_data)
     for point in scatter.points():
         pair_id = point.data().id
         self.pair_items[pair_id].extend([point, color])
     scatter.sigClicked.connect(self.scatter_plot_clicked)
     if len(tracesA) > 0:
         if field_name == 'Latency':
             spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
             trace_plt[0].addItem(spike_line)
             x_label = 'Time from presynaptic spike'
         else:
             x_label = 'Response Onset'
         grand_trace = TSeriesList(tracesA).mean()
         name = ('%s->%s, n=%d' % (pre_class, post_class, len(tracesA)))
         trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
         trace_plt[1].setXRange(-5e-3, 20e-3)
         trace_plt[1].setLabels(left=('', units), bottom=(x_label, 's'))
         trace_plt[1].setTitle(title)
     if len(tracesB) > 0:
         trace_plt[1].setLabels(right=('', units))
         trace_plt[1].hideAxis('left')
         spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
         trace_plt[0].addItem(spike_line)
         grand_trace = TSeriesList(tracesB).mean()
         trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
         trace_plt[0].setXRange(-5e-3, 20e-3)
         trace_plt[0].setLabels(left=('', 'A'), bottom=('Time from presynaptic spike', 's'))
         trace_plt[0].setTitle('Voltage Clamp')
     return line, scatter
예제 #18
0
import pyqtgraph as pg
from pyqtgraph.Qt import QtGui, QtCore
import numpy as np
from neuroanalysis.data import TSeries
from neuroanalysis.ui.event_detection import EventDetector
from neuroanalysis.ui.plot_grid import PlotGrid

# A Qt application must exist before any widget (the splitter and parameter
# tree below) is created.
pg.mkQApp()

# Example recordings. Arrays whose key starts with 'trace' are the traces;
# data['sample_rates'][i] is used as the sample rate of the i-th trace name
# in sorted order.
data = np.load("test_data/synaptic_events/events1.npz")
trace_names = sorted([x for x in data.keys() if x.startswith('trace')])
traces = {
    n: TSeries(data[n], dt=1.0 / data['sample_rates'][i])
    for i, n in enumerate(trace_names)
}

# Event detector with its detection threshold preset to 5e-10.
# NOTE(review): the threshold's units depend on the recorded data -- confirm.
evd = EventDetector()
evd.params['threshold'] = 5e-10

# Horizontal splitter plus a header-less parameter tree for the controls.
# NOTE(review): relies on pyqtgraph's QtGui shim exposing QSplitter (Qt4-style);
# Qt5/6 bindings place it in QtWidgets -- confirm with the installed version.
hs = QtGui.QSplitter(QtCore.Qt.Horizontal)
pt = pg.parametertree.ParameterTree(showHeader=False)

# Parameter group: a 'data' list for choosing one of the loaded traces,
# followed by the event detector's own parameter group.
params = pg.parametertree.Parameter.create(name='params',
                                           type='group',
                                           children=[
                                               dict(name='data',
                                                    type='list',
                                                    values=trace_names),
                                               evd.params,
                                           ])
예제 #19
0
    def set_data(self, data, show_spikes=False, subtract_baseline=True):
        """Draw spike-aligned pulse responses, and their stored fits, on the plots.

        *data* maps each recording to a mapping of pulse number -> record.

        Alignment and layout of each record's response:

        * t=0 is the presynaptic spike, or 1 ms after stimulus onset when no
          spike time was recorded;
        * when *subtract_baseline* is true, the trace is shifted vertically so
          that its pre-spike median sits at zero;
        * it is shifted horizontally by 35 ms per pulse number, plus an extra
          30 ms for pulse numbers above 8.

        QC-passing responses are drawn at z=0 and failing ones at z=-10, each
        with its own semi-transparent pen.

        When *show_spikes* is true, the spike recordings are drawn on
        ``self.spike_plot`` with the same offsets. For records with a stored
        fit (``PulseResponseFit.fit_amp`` not None), the evaluated StackedPsp
        curve is overlaid at z=10, and QC failures are printed.

        Auto-ranging of both plots is disabled while drawing and re-enabled
        at the end.
        """
        self.spike_plot.setVisible(show_spikes)
        self.spike_plot.enableAutoRange(False, False)
        self.data_plot.enableAutoRange(False, False)
        model = StackedPsp()

        for recording in data:
            responses = data[recording]
            for pulse_n in sorted(responses.keys()):
                rec = responses[pulse_n]

                # align on the presynaptic spike; fall back to 1 ms after onset
                spike_t = rec.StimPulse.first_spike_time
                if spike_t is None:
                    spike_t = rec.StimPulse.onset_time + 1e-3

                if rec.Synapse.synapse_type == 'in':
                    qc_pass = rec.PulseResponse.in_qc_pass
                else:
                    qc_pass = rec.PulseResponse.ex_qc_pass
                trace_pen = (255, 255, 255, 100) if qc_pass else (200, 50, 0, 100)

                response_start = rec.PulseResponse.data_start_time - spike_t
                ts = TSeries(data=rec.data,
                             t0=response_start,
                             sample_rate=db.default_sample_rate)
                curve = self.data_plot.plot(ts.time_values, ts.data, pen=trace_pen)

                # vertical offset removes the pre-spike baseline; horizontal
                # offset lays the pulses out side by side
                if subtract_baseline:
                    y0 = ts.time_slice(None, 0).median()
                else:
                    y0 = 0
                x_shift = pulse_n * 35e-3
                if pulse_n > 8:
                    x_shift += 30e-3
                offset = (x_shift, -y0)
                z_value = 0 if qc_pass else -10
                curve.setPos(*offset)
                curve.setZValue(z_value)

                if show_spikes:
                    spike_start = rec.spike_data_start_time - spike_t
                    spike_ts = TSeries(data=rec.spike_data,
                                       t0=spike_start,
                                       sample_rate=db.default_sample_rate)
                    spike_curve = self.spike_plot.plot(spike_ts.time_values,
                                                       spike_ts.data,
                                                       pen=trace_pen)
                    spike_curve.setPos(*offset)
                    spike_curve.setZValue(z_value)

                # overlay the stored fit, if there is one
                fit_par = rec.PulseResponseFit
                if fit_par.fit_amp is None:
                    continue
                # NOTE(review): exp_tau is fed the stored decay tau; presumably
                # the stacked fit ties the two together -- confirm.
                fit = model.eval(
                    x=ts.time_values,
                    exp_amp=fit_par.fit_exp_amp,
                    exp_tau=fit_par.fit_decay_tau,
                    amp=fit_par.fit_amp,
                    rise_time=fit_par.fit_rise_time,
                    decay_tau=fit_par.fit_decay_tau,
                    xoffset=fit_par.fit_latency,
                    yoffset=fit_par.fit_yoffset,
                    rise_power=2,
                )
                fit_pen = (0, 255, 0, 100) if qc_pass else (50, 150, 0, 100)
                fit_curve = self.data_plot.plot(ts.time_values, fit, pen=fit_pen)
                fit_curve.setZValue(10)
                fit_curve.setPos(*offset)

                if not qc_pass:
                    print(
                        "qc fail: ",
                        rec.PulseResponse.meta.get('qc_failures',
                                                   'no qc failures recorded'))

        self.spike_plot.enableAutoRange(True, True)
        self.data_plot.enableAutoRange(True, True)
예제 #20
0
def analyze_response_strength(rec,
                              source,
                              remove_artifacts=False,
                              deconvolve=True,
                              lpf=True,
                              bsub=True,
                              lowpass=1000):
    """Run the standard strength analysis on one pulse-response or baseline record.

    Steps:

    1. Determine the presynaptic stimulus pulse edges and spike time,
       relative to the start of the record.
    2. Measure crosstalk and the peak deflections on the raw trace.
    3. Apply deconvolution / artifact removal / low-pass filtering.
    4. Measure peak deflections and their latencies on the filtered trace.

    Parameters
    ----------
    rec :
        Record selected by response_query or baseline_query.
    source : str
        Either 'pulse_response' or 'baseline'. Baseline records get fake
        stimulus timing (pulse 10-12 ms, spike at 11 ms) so that background
        data receives the same filtering / windowing as real responses.
    remove_artifacts, lpf, bsub, lowpass :
        Forwarded to deconv_filter. In voltage clamp ('vc') with
        remove_artifacts True, crosstalk artifacts are removed before
        deconvolution instead, and deconv_filter is told not to remove them
        again.
    deconvolve : bool
        If true, deconvolve with tau = 15 ms for 'ic' records and 5 ms
        otherwise. If false, tau=None is passed to deconv_filter.

    Returns
    -------
    dict
        Keys: 'raw_trace', 'pulse_times', 'spike_time', 'crosstalk',
        'pos_amp', 'neg_amp', 'dec_trace', 'pos_dec_amp', 'pos_dec_latency',
        'neg_dec_amp', 'neg_dec_latency'.

    Raises
    ------
    ValueError
        If *source* is not one of the two recognized values.
    """
    trace = TSeries(rec.data, sample_rate=db.default_sample_rate)

    if source == 'pulse_response':
        onset = rec.pulse_start - rec.rec_start
        pulse_times = [onset, onset + rec.pulse_dur]
        # responses without a spike time failed QC; they are still analyzed
        # (with a nominal spike time) so that all data stay visible
        spike_time = 11e-3 if rec.spike_time is None else rec.spike_time - rec.rec_start
    elif source == 'baseline':
        pulse_times = [10e-3, 12e-3]
        spike_time = 11e-3
    else:
        raise ValueError("Invalid source %s" % source)

    results = {
        'raw_trace': trace,
        'pulse_times': pulse_times,
        'spike_time': spike_time,
    }

    # crosstalk = step in the median signal across the pulse onset (+/-200 us)
    onset = pulse_times[0]
    before = trace.time_slice(onset - 200e-6, onset).median()
    after = trace.time_slice(onset, onset + 200e-6).median()
    results['crosstalk'] = after - before

    # in VC the crosstalk artifacts are removed before deconvolution
    if rec.clamp_mode == 'vc' and remove_artifacts is True:
        trace = remove_crosstalk_artifacts(trace, pulse_times)
        remove_artifacts = False

    for sign, amp_key in (('+', 'pos_amp'), ('-', 'neg_amp')):
        amp, _latency = measure_peak(trace, sign, spike_time, pulse_times)
        results[amp_key] = amp

    if deconvolve:
        tau = 15e-3 if rec.clamp_mode == 'ic' else 5e-3
    else:
        tau = None
    dec_trace = deconv_filter(trace,
                              pulse_times,
                              tau=tau,
                              lpf=lpf,
                              remove_artifacts=remove_artifacts,
                              bsub=bsub,
                              lowpass=lowpass)
    results['dec_trace'] = dec_trace

    peak_keys = (
        ('+', 'pos_dec_amp', 'pos_dec_latency'),
        ('-', 'neg_dec_amp', 'neg_dec_latency'),
    )
    for sign, amp_key, latency_key in peak_keys:
        amp, latency = measure_peak(dec_trace, sign, spike_time, pulse_times)
        results[amp_key] = amp
        results[latency_key] = latency

    return results
예제 #21
0
def measure_response(rec, baseline_rec):
    """Curve-fit a single pulse response to measure its amplitude / kinetics.

    The fit is seeded with the synapse's known latency and kinetics:

    * the search window is ``rec.latency`` +/- 100 us;
    * the initial rise time and decay tau are the PSP values for current
      clamp ('ic') records and the PSC values otherwise.

    The fit sign is constrained as follows, and is negated in voltage clamp
    ('vc'):

    * +1 for excitatory synapses;
    * -1 for inhibitory synapses when ``rec.baseline_potential > -60 mV``;
    * 0 (unconstrained) in all other cases.

    Returns
    -------
    (response_fit, baseline_fit)
        ``(None, None)`` when the spike time, latency, rise time or decay tau
        is missing or non-finite. Otherwise ``baseline_fit`` is None when
        *baseline_rec* is None, else a fit of the baseline data with the same
        time axis and fit settings (used for noise measurement).
    """
    if rec.clamp_mode == 'ic':
        rise_time, decay_tau = rec.psp_rise_time, rec.psp_decay_tau
    else:
        rise_time, decay_tau = rec.psc_rise_time, rec.psc_decay_tau

    # every seed parameter must be present and finite
    required = [rec.spike_time, rec.latency, rise_time, decay_tau]
    if rec.latency is None or any(v is None or not np.isfinite(v) for v in required):
        return None, None

    # time axis relative to the presynaptic spike
    data = TSeries(rec.data,
                   t0=rec.rec_start - rec.spike_time,
                   sample_rate=db.default_sample_rate)

    synapse_type = rec.synapse_type
    if synapse_type == 'ex':
        sign = 1
    elif synapse_type == 'in' and rec.baseline_potential > -60e-3:
        sign = -1
    else:
        sign = 0
    if rec.clamp_mode == 'vc':
        sign = -sign

    def _fit(trace):
        # identical seeding for response and baseline; the window and
        # init_params are rebuilt on every call
        return fit_psp(
            trace,
            search_window=rec.latency + np.array([-100e-6, 100e-6]),
            clamp_mode=rec.clamp_mode,
            sign=sign,
            baseline_like_psp=True,
            init_params={
                'rise_time': rise_time,
                'decay_tau': decay_tau
            },
            refine=False,
        )

    response_fit = _fit(data)

    if baseline_rec is None:
        baseline_fit = None
    else:
        baseline = TSeries(baseline_rec.data,
                           t0=data.t0,
                           sample_rate=db.default_sample_rate)
        baseline_fit = _fit(baseline)

    return response_fit, baseline_fit
def save_fit_psp_test_set():
    """Create a test set of data for testing the fit_psp function.

    NOTE: THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Uses Steph's original first-pulse-feature code (referred to as
    first_puls_feature.py) to filter out error-causing data.

    For each selected connection type and each connected pair whose cre type
    and target layer match, it:

    * collects pulse responses, filtered by frequency (0-50 Hz) and holding
      potential;
    * QC-checks them and averages them;
    * builds a fit-weight array: 10 everywhere, 0 for 10-12 ms around the
      stimulus artifact, 30 for 12-19 ms around the steep PSP rise.

    The saved inputs are collected in ``save_dict['input']``. fit_psp is run
    both on the original average trace and on a TSeries rebuilt from the
    saved inputs, and a message is printed when their nrmse values differ.
    The ``test_psp_fit`` directory is created if missing.

    Example run statement::

        python save_fit_psp_test_set.py --organism mouse --connection ee

    The original instructions say to uncomment the saving code at the bottom.
    NOTE(review): no saving code is present at the bottom of this copy of
    the function; only the fit comparison is performed.
    """

    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TSeriesList, TSeries
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description=
        'Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
        'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism',
                        dest='organism',
                        help='Select mouse or human')
    parser.add_argument('--connection',
                        dest='connection',
                        help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {
        'Type': [],
        'Connection': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'rise2080': [],
        'rise1090': [],
        'rise1080': [],
        'decay': [],
        'nrmse': [],
        'CV': []
    }
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # connection_types is indexed below, so dict keys are materialized with
    # list() (on Python 3, dict.keys() returns a non-indexable view)
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(ee_connections.keys())
        elif connection == 'ii':
            connection_types = list(ii_connections.keys())
        elif connection == 'ei':
            connection_types = list(ei_connections.keys())
        elif connection == 'ie':
            connection_types = list(ie_connections.keys())
        elif connection == 'all':
            connection_types = list(all_connections.keys())
        elif len(connection.split('-')) == 2:
            # "pre-post" cre pair; '2/3' selects layer 2/3 with unknown cre
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                post_type = (None, c_type[1])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(human_connections.keys())
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1],
                                                          'unknown'))]

    plt = pg.plot()

    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type,
                                     target_layer=target_layer,
                                     calcium=calcium,
                                     age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {
            'trace': [],
            'amp': [],
            'latency': [],
            'rise': [],
            'dist': [],
            'decay': [],
            'CV': [],
            'amp_measured': []
        }
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[
                    0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[
                    0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(
                        expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(
                        pulse_response,
                        freq_range=[0, 50],
                        holding_range=holding,
                        pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset,
                                           baseline=1.5,
                                           pulse=None,
                                           plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(
                                qc_list)
                            #                        if amp_sign is '-':
                            #                            continue
                            #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                            #                        all_amps = fail_rate(response_subset, '+', peak_t)
                            #                        cv = np.std(all_amps)/np.mean(all_amps)
                            #
                            #                        # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(
                                len(avg_trace.data
                                    )) * 10.  #set everything to ten initially
                            weight[int(10e-3 / dt):int(
                                12e-3 / dt)] = 0.  #area around stim artifact
                            weight[int(12e-3 / dt):int(
                                19e-3 / dt)] = 30.  #area around steep PSP rise

                            # check if the test data dir is there and if not create it
                            test_data_dir = 'test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)

                            save_dict = {}
                            save_dict['input'] = {
                                'data': avg_trace.data.tolist(),
                                'dtype': str(avg_trace.data.dtype),
                                'dt': float(avg_trace.dt),
                                'amp_sign': amp_sign,
                                'yoffset': 0,
                                'xoffset': 14e-3,
                                'avg_amp': float(avg_amp),
                                'method': 'leastsq',
                                'stacked': False,
                                'rise_time_mult_factor': 10.,
                                'weight': weight.tolist()
                            }

                            # need to remake trace because different output is created
                            avg_trace_simple = TSeries(
                                data=np.array(save_dict['input']['data']),
                                dt=save_dict['input']
                                ['dt'])  # create TSeries object

                            psp_fits_original = fit_psp(
                                avg_trace,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']
                                ['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })

                            psp_fits_simple = fit_psp(
                                avg_trace_simple,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']
                                ['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })
                            # %-formatting keeps the output identical on
                            # Python 2 and 3 (a multi-argument print() would
                            # print a tuple on Python 2)
                            print('%s %s %s' % (expt.uid, pre, post))
                            if psp_fits_original.nrmse(
                            ) != psp_fits_simple.nrmse():
                                print('  the nrmse values dont match')
                                print('\toriginal %s' % psp_fits_original.nrmse())
                                print('\tsimple %s' % psp_fits_simple.nrmse())
예제 #23
0
def plot_features(organism=None,
                  conn_type=None,
                  calcium=None,
                  age=None,
                  sweep_thresh=None,
                  fit_thresh=None):
    """Plot first-pulse average PSPs for every combination of filter values.

    Each filter argument is either None (no filtering on that field) or a
    comma-separated string of values. One row of the response grid is drawn
    for every combination (cartesian product) of the supplied values. Each
    row holds every matching pair's ``avg_psp`` plus their mean in blue.

    Parameters
    ----------
    organism : str or None
        Comma-separated species names matched against ``db.Slice.species``.
    conn_type : str or None
        Comma-separated connection types, each written as
        ``'<pre_layer>;<pre_cre>-<post_layer>;<post_cre>'``. A layer given as
        the literal string ``'None'`` is compared against None.
    calcium : str or None
        Comma-separated prefixes matched (SQL LIKE '<prefix>%') against
        ``db.Experiment.acsf``.
    age : str or None
        Comma-separated ``'<lower>-<upper>'`` integer ranges applied with
        ``db.Slice.age.between``.
    sweep_thresh : int or None
        Minimum ``FirstPulseFeatures.n_sweeps`` for a pair to be included.
    fit_thresh :
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    (PlotGrid, PlotGrid)
        The response grid (one row per selection) and a 6x1 feature grid.
        Nothing is plotted on the feature grid here.
    """
    s = db.session()

    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age
    }

    # Expand the comma-separated filter values into the cartesian product of
    # selections, e.g. organism='mouse,human' -> [{'organism': 'mouse'},
    # {'organism': 'human'}]. ``.items()`` works on both Python 2 and 3.
    selection = [{}]
    for key, value in filters.items():
        if value is None:
            continue
        selection = [dict(prev, **{key: v})
                     for v in value.split(',')
                     for prev in selection]

    # ``selection`` always has at least one entry (it starts as [{}] and every
    # filter value yields at least one item), so the grids are always created.
    response_grid = PlotGrid()
    response_grid.set_shape(len(selection), 1)
    response_grid.show()
    feature_grid = PlotGrid()
    feature_grid.set_shape(6, 1)
    feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = aliased(db.Cell)
        post_cell = aliased(db.Cell)
        q_filter = []
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)

        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([
                pre_cell.cre_type == pre_cre,
                pre_cell.target_layer == pre_layer,
                post_cell.cre_type == post_cre,
                post_cell.target_layer == post_layer
            ])
            label = 'layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre,
                                                      post_layer, post_cre)
        else:
            # Without a conn_type filter there are no layer/cre names; the
            # original code raised NameError here when building the title.
            label = 'all connection types'

        calc_conc = select.get('calcium')
        if calc_conc is not None:
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(
                db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id==db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id==pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id==post_cell.id)\
            .join(db.Experiment, db.Experiment.id==db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id==db.Experiment.slice_id)

        for filter_arg in q_filter:
            q = q.filter(filter_arg)

        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = TSeries(data=pair.avg_psp,
                            sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TSeriesList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values,
                                     grand_trace.data,
                                     pen='b')
            response_grid[i, 0].setTitle(
                '%s; n_synapses = %d' % (label, len(trace_list)))
        else:
            print('No synapses for %s' % label)

    return response_grid, feature_grid
예제 #24
0
    def _plot_onset_aligned_trace(self, plot, pair, trace, start_time, latency):
        """Baseline-subtract and shift `trace`, then plot it as a clickable curve.

        The trace is passed through ``format_trace`` with
        ``x_offset = start_time - latency`` and ``align='psp'``. The baseline
        is the 1 ms window ending at ``abs(x_offset)``. The plotted item is
        tagged with ``pair`` and wired to ``self.trace_plot_clicked``.

        Returns
        -------
        (plot_item, formatted_trace)
        """
        xoffset = start_time - latency
        baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
        trace = format_trace(trace, baseline_window, x_offset=xoffset, align='psp')
        item = plot.plot(trace.time_values, trace.data)
        item.pair = pair
        item.curve.setClickable(True)
        item.sigClicked.connect(self.trace_plot_clicked)
        return item, trace

    def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
        """Plot a summary line and onset-aligned average responses for one element.

        Parameters
        ----------
        pre_class, post_class :
            Used only to build the legend name ``'<pre_class>-><post_class>'``
            of the voltage-clamp grand average.
        element :
            Table of pairs with a ``'Connected'`` column whose index holds the
            pair objects; aggregated with ``self.summary_stat``.
            NOTE(review): presumably a pandas DataFrame -- confirm with caller.
        field_name :
            Column whose ``'metric_summary'`` aggregate positions the line.
        color :
            Pen color for the summary line and grand-average traces.
        trace_plt :
            Two plots: ``trace_plt[0]`` gets voltage-clamp traces and
            ``trace_plt[1]`` gets current-clamp traces. It must be provided
            whenever any connected pair has a QC-passing response fit.

        Side effects
        ------------
        For every connected pair that has a synapse, stores
        ``[vc_item, ic_item]`` in ``self.pair_items[pair.id]``. Either entry is
        None when no trace of that clamp mode was plotted for the pair.

        Returns
        -------
        (line, scatter)
            A non-movable ``pg.InfiniteLine`` at the summary value, and None.
        """
        summary = element.agg(self.summary_stat)
        val = summary[field_name]['metric_summary']
        line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
        scatter = None
        tracesA = []
        tracesB = []
        connections = element[element['Connected'] == True].index.tolist()
        for pair in connections:
            synapse = pair.synapse
            if synapse is None:
                continue
            latency = synapse.latency
            arfs = pair.avg_response_fits
            trace_itemA = None  # first voltage-clamp trace plotted for this pair
            trace_itemB = None  # first current-clamp trace plotted for this pair
            # Traces are only aligned (and plotted) when the latency is known.
            if arfs and latency is not None:
                holdings = syn_typ_holding[synapse.synapse_type]
                for arf in arfs:
                    if arf.holding not in holdings or arf.manual_qc_pass is not True:
                        continue
                    start_time = arf.avg_data_start_time
                    if start_time is None:
                        continue
                    if arf.clamp_mode == 'vc' and trace_itemA is None:
                        traceA = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                        # Only voltage-clamp responses are low-pass Bessel filtered.
                        traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
                        trace_itemA, traceA = self._plot_onset_aligned_trace(
                            trace_plt[0], pair, traceA, start_time, latency)
                        tracesA.append(traceA)
                    elif arf.clamp_mode == 'ic' and trace_itemB is None:
                        traceB = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                        trace_itemB, traceB = self._plot_onset_aligned_trace(
                            trace_plt[1], pair, traceB, start_time, latency)
                        tracesB.append(traceB)
            self.pair_items[pair.id] = [trace_itemA, trace_itemB]

        if len(tracesA) > 0:
            grand_trace = TSeriesList(tracesA).mean()
            name = ('%s->%s' % (pre_class, post_class))
            trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
            trace_plt[0].setXRange(-5e-3, 20e-3)
            trace_plt[0].setLabels(left=('', 'A'), bottom=('Response Onset', 's'))
            trace_plt[0].setTitle('Voltage Clamp')
        if len(tracesB) > 0:
            grand_trace = TSeriesList(tracesB).mean()
            trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
            trace_plt[1].setLabels(right=('', 'V'), bottom=('Response Onset', 's'))
            trace_plt[1].setTitle('Current Clamp')
        return line, scatter