def simulate_response(fg_recs, bg_results, amp, rtime, seed=None):
    """Add simulated PSPs to foreground records and run connectivity analysis on them.

    Each record in *fg_recs* receives one synthetic event built from a PSP template
    (rise power 2, 15 ms decay tau, 15 ms long). Event amplitudes are drawn from a
    binomial(n=24, p=0.2) x normal(1, 0.3) distribution and rescaled so their mean
    equals *amp*; onset latencies are drawn from a normal distribution centered on
    13 ms with 200 us standard deviation.

    Parameters
    ----------
    fg_recs : sequence
        Records with a ``data`` array. They are wrapped in ``RecordWrapper`` and their
        data copied, so the original records are not modified.
    bg_results
        Passed unchanged to ``analyze_pair_connectivity`` as the ('ic', 'bg') entry.
    amp : float
        Target mean amplitude of the simulated events.
    rtime : float
        Rise time of the simulated PSP template.
    seed : int or None
        If given, numpy's global RNG is seeded before any random draws.

    Returns
    -------
    conn_result
        Result of ``analyze_pair_connectivity``.
    traces : list of Trace
        The modified foreground traces; each carries an ``amp`` attribute holding the
        simulated event amplitude.
    """
    if seed is not None:
        np.random.seed(seed)

    sample_rate = db.default_sample_rate
    dt = 1.0 / sample_rate
    t = np.arange(0, 15e-3, dt)
    template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

    n_recs = len(fg_recs)
    r_amps = scipy.stats.binom.rvs(p=0.2, n=24, size=n_recs) * scipy.stats.norm.rvs(scale=0.3, loc=1, size=n_recs)
    mean_amp = r_amps.mean() if n_recs > 0 else 0.0
    # Rescale so the mean simulated amplitude equals *amp*. If every binomial draw
    # was zero, the mean is 0; leave the amplitudes at zero rather than producing NaNs.
    if mean_amp != 0:
        r_amps *= amp / mean_amp
    r_latency = np.random.normal(size=n_recs, scale=200e-6, loc=13e-3)

    fg_results = []
    traces = []
    fg_recs = [RecordWrapper(rec) for rec in fg_recs]  # can't modify fg_recs, so we wrap records with a mutable shell
    for k, rec in enumerate(fg_recs):
        rec.data = rec.data.copy()
        start = max(0, int(r_latency[k] * sample_rate))
        # The record may end before or after the template does; only add the
        # overlapping samples so the two slices always have the same length.
        length = max(0, min(len(rec.data) - start, len(template)))
        rec.data[start:start + length] += template[:length] * r_amps[k]

        fg_result = analyze_response_strength(rec, 'baseline')
        fg_results.append(fg_result)

        traces.append(Trace(rec.data, sample_rate=sample_rate))
        traces[-1].amp = r_amps[k]
    fg_results = str_analysis_result_table(fg_results, fg_recs)
    conn_result = analyze_pair_connectivity({('ic', 'fg'): fg_results, ('ic', 'bg'): bg_results, ('vc', 'fg'): [], ('vc', 'bg'): []}, sign=1)
    return conn_result, traces
Example #2
0
def simulate_response(fg_recs, bg_results, amp, rtime, seed=None):
    """Add simulated PSPs to foreground records and run connectivity analysis on them.

    Each record in *fg_recs* receives one synthetic event built from a PSP template
    (rise power 2, 15 ms decay tau, 15 ms long). Event amplitudes are drawn from a
    binomial(n=24, p=0.2) x normal(1, 0.3) distribution and rescaled so their mean
    equals *amp*; onset latencies are drawn from a normal distribution centered on
    13 ms with 200 us standard deviation.

    Parameters
    ----------
    fg_recs : sequence
        Records with a ``data`` array. They are wrapped in ``RecordWrapper`` and their
        data copied, so the original records are not modified.
    bg_results
        Passed unchanged to ``analyze_pair_connectivity`` as the ('ic', 'bg') entry.
    amp : float
        Target mean amplitude of the simulated events.
    rtime : float
        Rise time of the simulated PSP template.
    seed : int or None
        If given, numpy's global RNG is seeded before any random draws.

    Returns
    -------
    conn_result
        Result of ``analyze_pair_connectivity``.
    traces : list of Trace
        The modified foreground traces; each carries an ``amp`` attribute holding the
        simulated event amplitude.
    """
    if seed is not None:
        np.random.seed(seed)

    sample_rate = db.default_sample_rate
    dt = 1.0 / sample_rate
    t = np.arange(0, 15e-3, dt)
    template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

    n_recs = len(fg_recs)
    r_amps = scipy.stats.binom.rvs(p=0.2, n=24, size=n_recs) * scipy.stats.norm.rvs(scale=0.3, loc=1, size=n_recs)
    mean_amp = r_amps.mean() if n_recs > 0 else 0.0
    # Rescale so the mean simulated amplitude equals *amp*. If every binomial draw
    # was zero, the mean is 0; leave the amplitudes at zero rather than producing NaNs.
    if mean_amp != 0:
        r_amps *= amp / mean_amp
    r_latency = np.random.normal(size=n_recs, scale=200e-6, loc=13e-3)

    fg_results = []
    traces = []
    fg_recs = [RecordWrapper(rec) for rec in fg_recs]  # can't modify fg_recs, so we wrap records with a mutable shell
    for k, rec in enumerate(fg_recs):
        rec.data = rec.data.copy()
        start = max(0, int(r_latency[k] * sample_rate))
        # The record may end before or after the template does; only add the
        # overlapping samples so the two slices always have the same length.
        length = max(0, min(len(rec.data) - start, len(template)))
        rec.data[start:start + length] += template[:length] * r_amps[k]

        fg_result = analyze_response_strength(rec, 'baseline')
        fg_results.append(fg_result)

        traces.append(Trace(rec.data, sample_rate=sample_rate))
        traces[-1].amp = r_amps[k]
    fg_results = str_analysis_result_table(fg_results, fg_recs)
    conn_result = analyze_pair_connectivity({('ic', 'fg'): fg_results, ('ic', 'bg'): bg_results, ('vc', 'fg'): [], ('vc', 'bg'): []}, sign=1)
    return conn_result, traces
def test_exp_deconv_psp_params():
    """Check that exp_deconv_psp_params predicts the shape of an exp-deconvolved PSP.

    A PSP waveform is generated, exponentially deconvolved using its own decay time
    constant, and compared against a PSP evaluated with the parameters returned by
    ``exp_deconv_psp_params``. The two should agree to within 0.02.
    """
    from neuroanalysis.event_detection import exp_deconvolve, exp_deconv_psp_params
    from neuroanalysis.data import TSeries
    from neuroanalysis.fitting import Psp

    time = np.linspace(0, 0.02, 10000)
    true_params = dict(amp=1, rise_time=4e-3, decay_tau=10e-3, rise_power=2)
    model = Psp()

    def evaluate(amp, rise_time, decay_tau, rise_power):
        # every waveform in this test starts at t=0 with no vertical offset
        return model.eval(x=time, xoffset=0, yoffset=0, amp=amp, rise_time=rise_time,
                          decay_tau=decay_tau, rise_power=rise_power)

    # Deconvolve the reference PSP with its own decay constant
    original = evaluate(**true_params)
    deconvolved = exp_deconvolve(TSeries(original, time_values=time), true_params['decay_tau']).data

    # Predict the deconvolved shape analytically
    dec_amp, dec_rise_time, dec_rise_power, dec_decay_tau = exp_deconv_psp_params(
        true_params['amp'], true_params['rise_time'], true_params['rise_power'], true_params['decay_tau'])
    predicted = evaluate(amp=dec_amp, rise_time=dec_rise_time,
                         decay_tau=dec_decay_tau, rise_power=dec_rise_power)

    # deconvolution drops one sample, hence the [1:] offset
    assert np.allclose(deconvolved, predicted[1:], atol=0.02)
Example #4
0
    def load_saved_fit(self, record):
        """Restore a previously saved pair analysis from *record* into the UI.

        Reads the saved analysis dict from ``record.notes`` and repopulates the control
        panel: synapse / gap-junction calls, warnings, comments, fit parameters and
        fit-pass flags. It also positions the latency line and re-plots each saved fit
        curve over the averaged spike-aligned responses, for every clamp mode and
        holding potential.

        If the notes have no ``'fit_parameters'`` entry, only the calls, warnings and
        comments are restored.

        Parameters
        ----------
        record
            Object whose ``notes`` attribute is a dict containing at least
            ``'synapse_type'`` and ``'gap_junction'``. It may also contain:

            - ``'fit_warnings'`` (list of str)
            - ``'comments'``
            - ``'fit_parameters'``, with ``'fit'`` and ``'initial'`` sub-dicts keyed by
              mode and then by holding-potential string

            When ``'fit_parameters'`` is present, ``'fit_pass'`` (keyed the same way) is
            read as well.
        """
        data = record.notes        
        pair_params = {'Synapse call': data['synapse_type'], 'Gap junction call': data['gap_junction']}
        self.ctrl_panel.update_user_params(**pair_params)
        self.warnings = data.get('fit_warnings', [])
        self.ctrl_panel.output_params.child('Warnings').setValue('\n'.join(self.warnings))
        self.ctrl_panel.output_params.child('Comments', '').setValue(data.get('comments', ''))

        # some records may be stored with no fit if a synapse is not present.
        if 'fit_parameters' not in data:
            return
            
        self.fit_params = data['fit_parameters']
        self.ctrl_panel.update_fit_params(data['fit_parameters']['fit'])
        self.output_fit_parameters = data['fit_parameters']['fit']
        self.initial_fit_parameters = data['fit_parameters']['initial']

        # Initial latency guess per mode: the first truthy saved xoffset among the
        # '-55' and '-70' holdings, falling back to 1 ms. Because `or` is used, a
        # saved xoffset of exactly 0 is also treated as missing.
        initial_vc_latency = (
            data['fit_parameters']['initial']['vc']['-55'].get('xoffset') or
            data['fit_parameters']['initial']['vc']['-70'].get('xoffset') or
            1e-3
        )
        initial_ic_latency = (
            data['fit_parameters']['initial']['ic']['-55'].get('xoffset') or
            data['fit_parameters']['initial']['ic']['-70'].get('xoffset') or
            1e-3
        )

        # ic latency minus vc latency; within 100 us the two modes are treated as agreeing
        latency_diff = np.diff([initial_vc_latency, initial_ic_latency])[0]
        if abs(latency_diff) < 100e-6:
            self.latency_superline.set_value(initial_vc_latency, block_fit=True)
        else:
            # Latencies disagree: prefer the mode with at least one passing fit
            # (vc first), defaulting to the vc latency if neither passed.
            # NOTE(review): `holdings` and `modes` are not defined in this method;
            # presumably module-level constants matching the '-55'/'-70' keys - confirm.
            fit_pass_vc = [data['fit_pass']['vc'][str(h)] for h in holdings]
            fit_pass_ic = [data['fit_pass']['ic'][str(h)] for h in holdings]
            if any(fit_pass_vc):
                self.latency_superline.set_value(initial_vc_latency, block_fit=True)
            elif any(fit_pass_ic):
                self.latency_superline.set_value(initial_ic_latency, block_fit=True)
            else:
                self.latency_superline.set_value(initial_vc_latency, block_fit=True)

        for mode in modes:
            for holding in holdings:
                fit_pass = data['fit_pass'][mode][str(holding)]
                self.ctrl_panel.output_params.child('Fit parameters', str(holding) + ' ' + mode.upper(), 'Fit Pass').setValue(fit_pass)
                # deep copy so the pops/setdefaults below don't alter the saved record
                fit_params = copy.deepcopy(data['fit_parameters']['fit'][mode][str(holding)])
                
                if fit_params:
                    # nrmse is a fit-quality metric, not a model parameter
                    fit_params.pop('nrmse', None)
                    # StackedPsp for current clamp, Psp for voltage clamp.
                    # NOTE(review): any other mode would leave `p` unset (or stale from a
                    # previous iteration).
                    if mode == 'ic':
                        p = StackedPsp() 
                    if mode == 'vc':
                        p = Psp()
                        
                    # make a list of spike-aligned postsynaptic tseries
                    tsl = PulseResponseList(self.sorted_responses[mode, holding]['qc_pass']).post_tseries(align='spike', bsub=True)
                    if len(tsl) == 0:
                        continue
                    
                    # average all together
                    avg = tsl.mean()
                    
                    # exp_tau may be absent from the saved fit; default it to decay_tau
                    fit_params.setdefault('exp_tau', fit_params['decay_tau'])
                    fit_psp = p.eval(x=avg.time_values, **fit_params)
                    fit_tseries = avg.copy(data=fit_psp)
                    
                    if mode == 'vc':
                        self.vc_plot.plot_fit(fit_tseries, holding, fit_pass=fit_pass)
                    if mode == 'ic':
                        self.ic_plot.plot_fit(fit_tseries, holding, fit_pass=fit_pass)            
    def add_connection_plots(i, name, timestamp, pre_id, post_id):
        """Add a row of plots for one connection and run a detectability simulation.

        Row *i* of ``win`` receives four plots:

        - spike-aligned response traces
        - their deconvolved counterparts
        - amplitude histograms (foreground vs. baseline)
        - a log-log detectability plot

        For the detectability plot, synthetic PSPs of varying amplitude and rise time
        are added to baseline recordings. The resulting statistic
        (median deconvolved amplitude / latency stdev) is compared with a resampled
        background distribution.

        Parameters
        ----------
        i : int
            Row index in ``win`` at which to add the plots.
        name : str
            Label displayed on the trace plot.
        timestamp : float
            Acquisition timestamp; matched to ``filtered['acq_timestamp']`` within 1.
        pre_id, post_id
            Pre-/postsynaptic cell ids matched against ``filtered``.

        Depends on names defined outside this function (``session``, ``win``,
        ``filtered``, ``trace_plots``, ``deconv_plots``, ``hist_plots``,
        ``scatter_plot``, ``background``, ``signal``, ``strength_analysis``, ...).
        """
        global session, win, filtered
        p = pg.debug.Profiler(disabled=True, delayed=False)
        trace_plot = win.addPlot(i, 1)
        trace_plots.append(trace_plot)
        deconv_plot = win.addPlot(i, 2)
        deconv_plots.append(deconv_plot)
        hist_plot = win.addPlot(i, 3)
        hist_plots.append(hist_plot)
        limit_plot = win.addPlot(i, 4)
        limit_plot.addLegend()
        limit_plot.setLogMode(True, True)
        # Find this connection in the pair list
        idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1) & (filtered['pre_cell_id'] == pre_id) & (filtered['post_cell_id'] == post_id))
        if idx.size == 0:
            print("not in filtered connections")
            return
        idx = idx[0,0]
        p()

        # Mark the point in scatter plot
        scatter_plot.plot([background[idx]], [signal[idx]], pen='k', symbol='o', size=10, symbolBrush='r', symbolPen=None)

        # Plot example traces and histograms
        for plts in [trace_plots, deconv_plots]:
            plt = plts[-1]
            plt.setXLink(plts[0])
            plt.setYLink(plts[0])
            plt.setXRange(-10e-3, 17e-3, padding=0)
            plt.hideAxis('left')
            plt.hideAxis('bottom')
            plt.addLine(x=0)
            plt.setDownsampling(auto=True, mode='peak')
            plt.setClipToView(True)
            # scale bars (data units): 2e-3 wide horizontal, 100e-6 tall vertical
            hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
            hbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(hbar)
            vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
            vbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(vbar)

        hist_plot.setXLink(hist_plots[0])

        pair = session.query(db.Pair).filter(db.Pair.id==filtered[idx]['pair_id']).all()[0]
        p()
        amps = strength_analysis.get_amps(session, pair)
        p()
        base_amps = strength_analysis.get_baseline_amps(session, pair)
        p()

        q = strength_analysis.response_query(session)
        p()
        q = q.join(strength_analysis.PulseResponseStrength)
        q = q.filter(strength_analysis.PulseResponseStrength.id.in_(amps['id']))
        q = q.join(db.Recording, db.Recording.id==db.PulseResponse.recording_id).join(db.PatchClampRecording).join(db.MultiPatchProbe)
        q = q.filter(db.MultiPatchProbe.induction_frequency < 100)

        fg_recs = q.all()
        p()

        traces = []
        deconvs = []
        for rec in fg_recs[:100]:
            result = strength_analysis.analyze_response_strength(rec, source='pulse_response', lpf=True, lowpass=2000,
                                                remove_artifacts=False, bsub=True)
            trace = result['raw_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            traces.append(trace)
            trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

            trace = result['dec_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            deconvs.append(trace)
            deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

        # plot average trace
        mean = TraceList(traces).mean()
        trace_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)
        mean = TraceList(deconvs).mean()
        deconv_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)

        # add label
        label = pg.LabelItem(name)
        label.setParentItem(trace_plot)

        p("analyze_response_strength")

        bins = np.arange(-0.001, 0.015, 0.0005)
        field = 'pos_dec_amp'
        n = min(len(amps), len(base_amps))
        hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen=None, brush=(200, 0, 0, 150), fillLevel=0)
        hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen='k', brush=(0, 150, 150, 100), fillLevel=0)
        p()

        pg.QtGui.QApplication.processEvents()

        # Plot detectability analysis
        q = strength_analysis.baseline_query(session)
        q = q.join(strength_analysis.BaselineResponseStrength)
        q = q.filter(strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
        bg_recs = q.all()

        def clicked(sp, pts):
            traces = pts[0].data()['traces']
            print([t.amp for t in traces])
            plt = pg.plot()
            bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
            for t in bsub:
                plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
            mean = TraceList(bsub).mean()
            plt.plot(mean.time_values, mean.data, pen='g')

        # first measure background a few times
        N = len(fg_recs)
        N = 50  # temporary for testing
        print("Testing %d trials" % N)

        # Null distribution: repeatedly resample N baseline measurements and compute
        # the same statistic that is computed for the simulated foreground below.
        bg_results = []
        M = 500
        print("  Grinding on %d background trials" % len(bg_recs))
        for k in range(M):
            shuffled = base_amps.copy()
            np.random.shuffle(shuffled)
            bg_results.append(np.median(shuffled[:N]['pos_dec_amp']) / np.std(shuffled[:N]['pos_dec_latency']))
            print("    %d/%d      \r" % (k, M), end='')
        print("    done.            ")
        print("    ", bg_results)
        bg_median = np.median(bg_results)

        # now measure foreground simulated under different conditions
        amps = 5e-6 * 2**np.arange(6)
        amps[0] = 0
        rtimes = 1e-3 * 1.71**np.arange(4)
        dt = 1 / db.default_sample_rate
        t = np.arange(0, 15e-3, dt)
        results = np.empty((len(amps), len(rtimes)), dtype=[('pos_dec_amp', float), ('latency_stdev', float), ('result', float), ('percentile', float), ('traces', object)])
        print("  Simulating synaptic events..")
        for j, rtime in enumerate(rtimes):
            # the template depends only on the rise time
            template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)
            for i, amp in enumerate(amps):
                trial_results = []
                trial_amps = []
                trial_latency_stds = []
                for _trial in range(20):
                    print("    %d/%d  %d/%d      \r" % (i, len(amps), j, len(rtimes)), end='')
                    r_amps = amp * 2**np.random.normal(size=N, scale=0.5)
                    r_latency = np.random.normal(size=N, scale=600e-6, loc=12.5e-3)
                    fg_results = []
                    traces = []
                    np.random.shuffle(bg_recs)
                    for k, rec in enumerate(bg_recs[:N]):
                        # can't modify rec, so we have to muck with the array in place;
                        # restore the original samples even if the analysis raises.
                        data = rec.data.copy()
                        try:
                            start = max(0, int(r_latency[k] / dt))
                            # only add the part of the template that fits in the record
                            length = max(0, min(len(rec.data) - start, len(template)))
                            rec.data[start:start + length] += template[:length] * r_amps[k]

                            fg_result = strength_analysis.analyze_response_strength(rec, 'baseline')
                            fg_results.append((fg_result['pos_dec_amp'], fg_result['pos_dec_latency']))

                            traces.append(Trace(rec.data.copy(), dt=dt))
                            traces[-1].amp = r_amps[k]
                        finally:
                            rec.data[:] = data

                    fg_amp = np.array([r[0] for r in fg_results])
                    fg_latency = np.array([r[1] for r in fg_results])
                    trial_amps.append(np.median(fg_amp))
                    trial_latency_stds.append(np.std(fg_latency))
                    trial_results.append(trial_amps[-1] / trial_latency_stds[-1])

                trial_median = np.median(trial_results)
                results[i, j]['pos_dec_amp'] = np.median(trial_amps)
                results[i, j]['latency_stdev'] = np.median(trial_latency_stds)
                results[i, j]['result'] = trial_median / bg_median
                # bg_results holds un-normalized statistics, so compare against the
                # un-normalized foreground statistic
                results[i, j]['percentile'] = stats.percentileofscore(bg_results, trial_median)
                results[i, j]['traces'] = traces

            # summarize the column just computed (all amplitudes for this rise time)
            column = results[:, j]
            assert all(np.isfinite(column['pos_dec_amp']))
            print(j, column['result'])
            print(j, column['percentile'])

            c = limit_plot.plot(amps, column['result'], pen=(j, len(rtimes)*1.3), symbol='o', antialias=True, name="%dus"%(rtime*1e6), data=column, symbolSize=4)
            c.scatter.sigClicked.connect(clicked)
            pg.QtGui.QApplication.processEvents()

        pg.QtGui.QApplication.processEvents()
Example #6
0
from neuroanalysis.fitting import Psp
from neuroanalysis.ui.fitting import FitExplorer

pg.mkQApp()
pg.dbg()

# Load PSP data from the test_data repository, or from a file given on the command line
if len(sys.argv) == 1:
    data_file = 'test_data/test_psp_fit/1485904693.10_8_2NOTstacked.json'
else:
    data_file = sys.argv[1]

# context manager so the file handle is closed as soon as the JSON is parsed
with open(data_file) as fh:
    data = json.load(fh)

y = np.array(data['input']['data'])
x = np.arange(len(y)) * data['input']['dt']

psp = Psp()
# each entry is a value, (value, min, max), or (value, 'fixed')
params = OrderedDict([
    ('xoffset', (10e-3, 10e-3, 15e-3)),
    ('yoffset', 0),
    ('amp', 0.1e-3),
    ('rise_time', (2e-3, 500e-6, 10e-3)),
    ('decay_tau', (4e-3, 1e-3, 50e-3)),
    ('rise_power', (2.0, 'fixed')),
])
fit = psp.fit(y, x=x, xtol=1e-3, maxfev=100, params=params)

# show the fit interactively (distinct name so the time axis `x` is not clobbered)
explorer = FitExplorer(fit=fit)
explorer.show()
def fit_psp(response, mode='ic', sign='any', xoffset=(11e-3, 10e-3, 15e-3), yoffset=(0, 'fixed'),
            mask_stim_artifact=True, method='leastsq', fit_kws=None, stacked=True,
            rise_time_mult_factor=2., **kwds):
    """Fit a PSP/PSC model to *response* and return the best fit.

    The model is ``StackedPsp`` when *stacked* is True, otherwise ``Psp``. With
    ``sign='any'`` both a positive and a negative initial amplitude are tried, and the
    fit with the lowest sum of squared residuals is kept.

    Parameters
    ----------
    response
        Time series with ``time_values``, ``data``, ``dt``, ``t0`` and ``meta``.
    mode : {'ic', 'vc'}
        Selects the initial amplitude, amplitude bound, rise time and decay tau.
    sign : {'+', '-', 'any'}
        Expected sign of the response amplitude.
    xoffset, yoffset
        Initial value / bounds for the onset time and baseline offset.
    mask_stim_artifact : bool
        If True, samples from 1 ms before to 1 ms after the initial xoffset get zero
        weight.
    method : str
        Optimization method passed to ``psp.fit``.
    fit_kws : dict or None
        Extra keywords for ``psp.fit``. Defaults to
        ``{'xtol': 1e-4, 'maxfev': 300, 'nan_policy': 'omit'}``. ``'weights'`` is added
        if absent; the caller's dict is not modified.
    stacked : bool
        Use ``StackedPsp`` (with an ``amp_ratio``-scaled exponential) instead of ``Psp``.
    rise_time_mult_factor : float
        The default rise-time bounds are ``rise_time / factor`` to ``rise_time * factor``.
    **kwds
        Overrides for individual fit parameters. Scalar ``rise_time`` / ``decay_tau``
        values are expanded to ``(v, v/2, v*2)``.

    Returns
    -------
    The best fit result. If ``response.meta`` contains ``'baseline_std'``, ``snr`` and
    ``err`` attributes are attached to it.

    Raises
    ------
    ValueError
        For an invalid *mode* or *sign*.
    Exception
        The fitting error is re-raised only if no candidate fit succeeded.
    """
    t = response.time_values
    y = response.data
    dt = response.dt

    if mode == 'ic':
        amp = .2e-3
        amp_max = 100e-3
        rise_time = 5e-3
        decay_tau = 50e-3
    elif mode == 'vc':
        amp = 20e-12
        amp_max = 500e-12
        rise_time = 1e-3
        decay_tau = 4e-3
    else:
        raise ValueError('mode must be "ic" or "vc"')

    amps = [(amp, 0, amp_max), (-amp, -amp_max, 0)]
    if sign == '-':
        amps = amps[1:]
    elif sign == '+':
        amps = amps[:1]
    elif sign != 'any':
        raise ValueError('sign must be "+", "-", or "any"')

    psp = StackedPsp() if stacked else Psp()

    # initial condition, lower boundary, upper boundary
    base_params = {
        'xoffset': xoffset,
        'yoffset': yoffset,
        'rise_time': (rise_time, rise_time/rise_time_mult_factor, rise_time*rise_time_mult_factor),
        'decay_tau': (decay_tau, decay_tau/10., decay_tau*10.),
        'rise_power': (2, 'fixed'),
    }

    if stacked:
        base_params.update({
            'exp_amp': 'amp * amp_ratio',
            'amp_ratio': (0, -100, 100),
        })

    if 'rise_time' in kwds:
        rt = kwds.pop('rise_time')
        if not isinstance(rt, tuple):
            rt = (rt, rt/2., rt*2.)
        base_params['rise_time'] = rt

    if 'decay_tau' in kwds:
        tau = kwds.pop('decay_tau')
        if not isinstance(tau, tuple):
            tau = (tau, tau/2., tau*2.)
        base_params['decay_tau'] = tau

    base_params.update(kwds)

    # one parameter set per candidate amplitude sign
    params = []
    for amp_init, amp_lo, amp_hi in amps:
        p2 = base_params.copy()
        p2['amp'] = (amp_init, amp_lo, amp_hi)
        params.append(p2)

    weight = np.ones(len(y))

    init_xoff = xoffset[0] if isinstance(xoffset, tuple) else xoffset
    onset_index = int((init_xoff - response.t0) / dt)
    # Clamp slice bounds at 0: a negative bound would wrap around to the end of
    # the array and weight the wrong samples.
    rise_start = max(0, onset_index + int(1e-3 / dt))
    rise_stop = max(0, onset_index + int(7e-3 / dt))
    weight[rise_start:rise_stop] = 3
    if mask_stim_artifact:
        # Use zero weight for fit region around the stimulus artifact
        i2 = max(0, onset_index + int(1e-3 / dt))
        i1 = max(0, onset_index + int(1e-3 / dt) - int(2e-3 / dt))
        weight[i1:i2] = 0

    if fit_kws is None:
        fit_kws = {'xtol': 1e-4, 'maxfev': 300, 'nan_policy': 'omit'}
    else:
        # copy so the caller's dict is not modified
        fit_kws = dict(fit_kws)
    fit_kws.setdefault('weights', weight)

    best_fit = None
    best_score = None
    for p in params:
        try:
            fit = psp.fit(y, x=t, params=p, fit_kws=fit_kws, method=method)
        except Exception:
            # only propagate the error if no candidate has produced a usable fit
            if p is params[-1] and best_fit is None:
                raise
            continue
        err = np.sum(fit.residual**2)
        if best_fit is None or err < best_score:
            best_fit = fit
            best_score = err
    fit = best_fit

    if 'baseline_std' in response.meta:
        fit.snr = abs(fit.best_values['amp']) / response.meta['baseline_std']
        fit.err = fit.rmse() / response.meta['baseline_std']

    return fit
def measure_deconvolved_response(pr):
    """Use exponential deconvolution and a template match to estimate the amplitude of a synaptic response.

    The post-synaptic response and the baseline recording of *pr* are each filtered
    with ``deconv_filter``. They are then cropped to the window around the expected
    event, and the amplitude of a fixed template is measured analytically. The
    template's latency and kinetics come from the synapse record. Measuring the
    baseline the same way gives a matching noise estimate.

    Parameters
    ----------
    pr : PulseResponse

    Returns
    -------
    list or tuple
        ``[response_fit, baseline_fit]``. Each entry is a dict with keys ``xoffset``,
        ``yoffset``, ``amp``, ``rise_time``, ``decay_tau``, ``rise_power`` and
        ``reconvolved_amp``, or None if that tseries is unavailable.

        Returns ``(None, None)`` when any of the following is missing or non-finite:
        the spike time, the synaptic latency, the rise time, or the decay tau.
    """
    syn = pr.pair.synapse
    pcr = pr.recording.patch_clamp_recording
    if pcr.clamp_mode == 'ic':
        rise_time = syn.psp_rise_time
        decay_tau = syn.psp_decay_tau
        lowpass = 2000
    else:
        rise_time = syn.psc_rise_time
        decay_tau = syn.psc_decay_tau
        lowpass = 6000

    # make sure all parameters are available
    for v in [
            pr.stim_pulse.first_spike_time, syn.latency, rise_time, decay_tau
    ]:
        if v is None or not np.isfinite(v):
            return None, None

    response_data = pr.get_tseries('post', align_to='spike')
    baseline_data = pr.get_tseries('baseline', align_to='spike')

    # Deconvolving a PSP-like shape yields a narrower PSP-like shape with lower rise power.
    # These deconvolved kinetics depend only on the synapse, so compute them once.
    dec_amp, dec_rise_time, dec_rise_power, dec_decay_tau = exp_deconv_psp_params(
        amp=1, rise_time=rise_time, decay_tau=decay_tau, rise_power=2)
    amp_ratio = 1 / dec_amp
    psp = Psp()

    ret = []
    for data in (response_data, baseline_data):
        if data is None:
            ret.append(None)
            continue

        filtered = deconv_filter(data,
                                 None,
                                 tau=decay_tau,
                                 lowpass=lowpass,
                                 remove_artifacts=False,
                                 bsub=True)

        # Chop down to the minimum we need to fit the deconvolved event.
        # There's a tradeoff here: with too much data we risk incorporating nearby
        # spontaneous events; with too little data we get more noise in the fit to baseline.
        filtered = filtered.time_slice(syn.latency - 1e-3,
                                       syn.latency + rise_time + 1e-3)

        # Two ways to measure the amplitude of an exp-deconvolved event:
        # 1) Direct curve fitting using the expected deconvolved rise/decay time constants.
        #    This allows some wiggle room in latency, but produces a weird butterfly-shaped
        #    background noise distribution.
        # 2) Analytically calculate the scale/offset of a fixed template. This uses a fixed
        #    latency, but produces a nice, normal-looking background noise distribution.
        # Method 2 is used here.
        template = psp.eval(
            x=filtered.time_values,
            xoffset=syn.latency,
            yoffset=0,
            amp=1,
            rise_time=dec_rise_time,
            decay_tau=dec_decay_tau,
            rise_power=dec_rise_power,
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            scale, offset = fit_scale_offset(filtered.data, template)

        # calculate amplitude of reconvolved event -- this is our best guess as to the
        # actual event amplitude
        reconvolved_amp = scale * amp_ratio

        fit = {
            'xoffset': syn.latency,
            'yoffset': offset,
            'amp': scale,
            'rise_time': dec_rise_time,
            'decay_tau': dec_decay_tau,
            # report the rise power actually used to build the template
            'rise_power': dec_rise_power,
            'reconvolved_amp': reconvolved_amp,
        }

        ret.append(fit)

    return ret
def fit_psp(
    response,
    mode='ic',
    sign='any',  #Note this will not be used if *amp* input is specified
    method='leastsq',
    fit_kws=None,
    stacked=True,
    rise_time_mult_factor=10.,  #Note this will not be used if *rise_time* input is specified 
    weight='default',
    amp_ratio='default',
    # the following are parameters that can be fit
    amp='default',
    decay_tau='default',
    rise_power='default',
    rise_time='default',
    exp_amp='default',
    xoffset='default',
    yoffset='default',
):
    """Fit a PSP (current clamp) or PSC (voltage clamp) model to *response*.

    The model is ``StackedPsp`` when *stacked* is True, otherwise ``Psp``. Every
    combination of initial conditions produced by ``create_all_fit_param_combos`` is
    fit, and the result with the lowest sum of squared residuals is returned.

    The default weighting makes assumptions about where the cross talk happens:
    traces are expected to be aligned to the pulses with ``response.t0`` set to zero.

    Parameters
    ----------
    response : neuroanalysis.data.Trace class
        Contains data on trace waveform.
    mode : string
        Either 'ic' for current clamp or 'vc' for voltage clamp.
    sign : string
        Specifies the sign of the PSP deflection. Must be '+', '-', or 'any'. If *amp*
        is specified, this value is irrelevant.
    method : string
        Method lmfit uses for optimization.
    fit_kws : dictionary or None
        Additional keywords forwarded to lmfit. ``'weights'`` is always set from
        *weight*. The caller's dict is not modified.
    stacked : True or False
        If True, use the StackedPsp function, which assumes there is still leftover
        voltage decay from previous events. If False, use the Psp function, which
        assumes the region of the waveform before the event is at baseline.
    rise_time_mult_factor : float
        Sets the default rise-time bounds to ``init / factor`` and ``init * factor``.
        Irrelevant if *rise_time* is provided.
    weight : 'default', False, or array
        Accepted values:

        - 'default': samples get weight 10, except zero at 10-12 ms (stimulus
          artifact) and 30 at 12-19 ms (PSP rise).
        - False: uniform weights.
        - An array: used as-is; it must match ``len(response.data)``.
    exp_amp : string
        Expression fed to lmfit (only used when *stacked*).

    The parameters below are fed to the psp function. Each value must be a tuple of
    the form (initial conditions, lower boundary, upper boundary). The initial
    conditions can be either a number or a list of numbers specifying several initial
    conditions. The initial condition may also be fixed by replacing the lower/upper
    boundary pair with 'fixed'.

    Examples:

        amplitude=(10, 0, 20)
        amplitude=(10, 'fixed')
        amplitude=([5,10, 20], 0, 20)
        amplitude=([5,10, 20], 'fixed')

    xoffset : tuple
        Horizontal shift (positive shifts to the right).
    yoffset : tuple
        Vertical offset.
    rise_time : tuple
        Time from beginning of psp until peak.
    decay_tau : tuple
        Decay time constant.
    amp : tuple
        The peak value of the psp.
    rise_power : tuple
        Exponent for the rising phase; larger values result in a slower activation.
    amp_ratio : tuple
        If *stacked*, sets the ratio between the residual decay amplitude and the
        height of the PSP.

    Returns
    -------
    fit: lmfit.model.ModelResult
        Best fit. If ``response.meta`` has ``'baseline_std'``, ``snr`` and ``err``
        attributes are attached.

    Raises
    ------
    ValueError
        For an invalid *mode* or *sign*.
    Exception
        If *stacked* is not a bool, or *weight* has the wrong length.
    """

    def _is_default(value):
        # isinstance guard: comparing a numpy array to a string is elementwise
        return isinstance(value, str) and value == 'default'

    # extracting these for ease of use
    t = response.time_values
    y = response.data
    dt = response.dt

    # Set initial conditions depending on whether in voltage or current clamp.
    # The sign of these is set later on based on the *sign* input.
    if mode == 'ic':
        amp_init = .2e-3
        amp_max = 100e-3
        rise_time_init = 5e-3
        decay_tau_init = 50e-3
    elif mode == 'vc':
        amp_init = 20e-12
        amp_max = 500e-12
        rise_time_init = 1e-3
        decay_tau_init = 4e-3
    else:
        raise ValueError('mode must be "ic" or "vc"')

    # Set up amplitude initial values and boundaries depending on whether *sign* is positive or negative
    if sign == '-':
        amps = (-amp_init, -amp_max, 0)
    elif sign == '+':
        amps = (amp_init, 0, amp_max)
    elif sign == 'any':
        # the warning is only relevant when amp is not explicitly supplied
        if _is_default(amp):
            warnings.warn(
                "You are not specifying the predicted sign of your psp.  This may slow down or mess up fitting"
            )
        amps = (0, -amp_max, amp_max)
    else:
        raise ValueError('sign must be "+", "-", or "any"')

    # initial condition, lower boundary, upper boundary
    base_params = {
        'xoffset': (14e-3, -float('inf'), float('inf')),
        'yoffset': (0, -float('inf'), float('inf')),
        'rise_time': (rise_time_init, rise_time_init / rise_time_mult_factor,
                      rise_time_init * rise_time_mult_factor),
        'decay_tau': (decay_tau_init, decay_tau_init / 10., decay_tau_init * 10.),
        'rise_power': (2, 'fixed'),
        'amp': amps,
    }

    # specify fitting function and set up conditions
    if not isinstance(stacked, bool):
        raise Exception("Stacked must be True or False")
    if stacked:
        psp = StackedPsp()
        base_params.update({
            #TODO: figure out the bounds on these
            'exp_amp': 'amp * amp_ratio',
            'amp_ratio': (0, -100, 100),
        })
    else:
        psp = Psp()

    # Override defaults with explicit input (only for parameters the chosen model
    # uses). An explicit mapping replaces the previous eval() on parameter names.
    supplied = {
        'xoffset': xoffset,
        'yoffset': yoffset,
        'rise_time': rise_time,
        'decay_tau': decay_tau,
        'rise_power': rise_power,
        'amp': amp,
        'exp_amp': exp_amp,
        'amp_ratio': amp_ratio,
    }
    for name in base_params:
        if not _is_default(supplied[name]):
            base_params[name] = supplied[name]

    # set weighting
    if _is_default(weight):
        # THIS CODE IS DEPENDENT ON THE DATA BEING INPUT IN A CERTAIN WAY THAT IS NOT TESTED
        weight = np.ones(len(y)) * 10.  #set everything to ten initially
        weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  #area around stim artifact
        weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  #area around steep PSP rise
    elif weight is False:  #do not weight any part of the stimulus
        weight = np.ones(len(y))
    elif len(weight) != len(y):  # user-supplied weights
        raise Exception(
            'the weight and array vectors are not the same length')

    # argument to be passed through to fitting function; keep any caller-supplied
    # lmfit keywords (they were previously discarded) without mutating the caller's dict
    fit_kws = {} if fit_kws is None else dict(fit_kws)
    fit_kws['weights'] = weight

    # convert initial parameters into a list of dictionaries to be consumed by psp.fit()
    param_dict_list = create_all_fit_param_combos(base_params)

    # cycle through the parameter sets and choose the best one
    best_fit = None
    best_score = None
    for p in param_dict_list:
        fit = psp.fit(y, x=t, params=p, fit_kws=fit_kws, method=method)
        # note: using this because normalized (nrmse) is not necessary to comparing fits within the same data set
        err = np.sum(fit.residual**2)
        if best_fit is None or err < best_score:
            best_fit = fit
            best_score = err
    fit = best_fit

    if 'baseline_std' in response.meta:
        fit.snr = abs(fit.best_values['amp']) / response.meta['baseline_std']
        # NOTE(review): the other fit_psp variant divides rmse() (not nrmse()) by
        # baseline_std; confirm which metric is intended here.
        fit.err = fit.nrmse() / response.meta['baseline_std']

    return fit
Example #10
0
import numpy as np
import pyqtgraph as pg
from collections import OrderedDict
from neuroanalysis.fitting import Psp
from neuroanalysis.ui.fitting import FitExplorer

# Demo: fit a single PSP from psp.csv and open an interactive FitExplorer on it.
import sys

app = pg.mkQApp()

# Read columns 0 and 1 of psp.csv (one header row skipped): x values and trace.
data = np.loadtxt('psp.csv', delimiter=',', skiprows=1, usecols=[0, 1])
x = data[:, 0]
y = data[:, 1]

psp = Psp()
# Initial parameter specs passed as keyword arguments to psp.fit().
# NOTE(review): tuples appear to follow a (value, min, max) / (value, 'fixed')
# convention -- confirm against Psp.fit / make_params.
params = OrderedDict([
    ('xoffset', (2e-3, 5e-4, 5e-3)),
    ('yoffset', 0),
    ('amp', 10e-12),
    ('rise_time', (2e-3, 50e-6, 10e-3)),
    ('decay_tau', (4e-3, 500e-6, 50e-3)),
    ('rise_power', (2.0, 'fixed')),
])
# NOTE(review): the trace is scaled by 1e12 but the initial amp (10e-12) is
# not; confirm the intended units. Also confirm that xtol/maxfev reach the
# solver -- other psp.fit() calls in this file pass options via params=/fit_kws=.
fit = psp.fit(y * 1e12, x=x, xtol=1e-3, maxfev=100, **params)

# Keep the window in its own name instead of clobbering the data array `x`.
explorer = FitExplorer(fit=fit)
explorer.show()

# Run the Qt event loop so the window stays open when executed as a script;
# skip it under `python -i` so the interactive prompt remains usable.
if sys.flags.interactive != 1:
    if hasattr(app, 'exec_'):
        app.exec_()  # Qt5-era bindings
    else:
        app.exec()  # Qt6 bindings only provide exec()
Example #11
0
    def _plot_pulse_response(self, rec):
        """Plot one pulse response and its stored fits into the viewer's plots.

        Draws, when the data is available:

        * the spike-aligned baseline trace into ``self.data_plots[1]``;
        * the spike-aligned presynaptic trace into ``self.spike_plots[0]`` and
          the spike-aligned postsynaptic trace into ``self.data_plots[0]``;
        * the recorded StackedPsp fit (pen ``(0, 255, 0, 100)``) and the same
          curve evaluated with ``dec_fit_reconv_amp`` (pen ``(200, 255, 0, 100)``)
          over the postsynaptic trace;
        * the deconvolved postsynaptic trace and its Psp fit into
          ``self.dec_plots[0]``.

        If there is no spike-aligned presynaptic trace, the pulse-aligned
        presynaptic trace (if any) is drawn with a red pen and the method
        returns. If ``rec.PulseResponseFit`` is None, only the raw traces are
        plotted.

        Parameters
        ----------
        rec :
            Record exposing ``PulseResponse`` (with ``get_tseries`` and
            ``recording.patch_clamp_recording.clamp_mode``) and
            ``PulseResponseFit`` (stored fit parameters, or None).
        """
        pr = rec.PulseResponse
        base_ts = pr.get_tseries('baseline', align_to='spike')
        if base_ts is not None:
            self.data_plots[1].plot(base_ts.time_values, base_ts.data)

        pre_ts = pr.get_tseries('pre', align_to='spike')

        # If there is no presynaptic spike time, plot the pulse-aligned
        # presynaptic trace in red and bail out.
        if pre_ts is None:
            pulse_ts = pr.get_tseries('pre', align_to='pulse')
            # get_tseries can return None (cf. the baseline check above);
            # guard so a missing trace does not raise AttributeError.
            if pulse_ts is not None:
                self.spike_plots[0].plot(pulse_ts.time_values,
                                         pulse_ts.data,
                                         pen=(255, 0, 0, 100))
            return

        post_ts = pr.get_tseries('post', align_to='spike')

        self.spike_plots[0].plot(pre_ts.time_values, pre_ts.data)
        self.data_plots[0].plot(post_ts.time_values, post_ts.data)

        # evaluate recorded fit for this response
        fit_par = rec.PulseResponseFit

        # If there is no fit, bail out here
        if fit_par is None:
            return

        spsp = StackedPsp()

        def plot_stacked_fit(amp, pen):
            # Evaluate the stored StackedPsp fit with the given amplitude and
            # overlay it on the postsynaptic trace. exp_tau reuses the fitted
            # decay tau, as in the stored-fit evaluation.
            curve = spsp.eval(
                x=post_ts.time_values,
                exp_amp=fit_par.fit_exp_amp,
                exp_tau=fit_par.fit_decay_tau,
                amp=amp,
                rise_time=fit_par.fit_rise_time,
                decay_tau=fit_par.fit_decay_tau,
                xoffset=fit_par.fit_latency,
                yoffset=fit_par.fit_yoffset,
                rise_power=2,
            )
            self.data_plots[0].plot(post_ts.time_values, curve, pen=pen)

        # recorded fit, then the same fit with the reconvolved amplitude
        plot_stacked_fit(fit_par.fit_amp, (0, 255, 0, 100))
        plot_stacked_fit(fit_par.dec_fit_reconv_amp, (200, 255, 0, 100))

        # plot deconvolution; tau and lowpass depend on the clamp mode
        clamp_mode = pr.recording.patch_clamp_recording.clamp_mode
        if clamp_mode == 'ic':
            decay_tau = self.loaded_pair.synapse.psp_decay_tau
            lowpass = 2000
        else:
            decay_tau = self.loaded_pair.synapse.psc_decay_tau
            lowpass = 6000

        dec = deconv_filter(post_ts,
                            None,
                            tau=decay_tau,
                            lowpass=lowpass,
                            remove_artifacts=False,
                            bsub=True)
        self.dec_plots[0].plot(dec.time_values, dec.data)

        # plot deconvolution fit
        psp = Psp()
        fit = psp.eval(
            x=dec.time_values,
            # NOTE(review): Psp.psp_func as called elsewhere in this file takes
            # xoffset/yoffset/rise_time/decay_tau/amp/rise_power but no
            # exp_tau; this keyword is likely ignored -- confirm and remove.
            exp_tau=fit_par.dec_fit_decay_tau,
            amp=fit_par.dec_fit_amp,
            rise_time=fit_par.dec_fit_rise_time,
            decay_tau=fit_par.dec_fit_decay_tau,
            xoffset=fit_par.dec_fit_latency,
            yoffset=fit_par.dec_fit_yoffset,
            rise_power=1,
        )
        self.dec_plots[0].plot(dec.time_values, fit, pen=(0, 255, 0, 100))