Esempio n. 1
0
def test_psp_fitting():
    """Test the fit_psp function against the saved data in test_data_dir.

    Note that this test is highly sensitive: it compares the fit results with
    exact equality.  If this test fails, check_psp_fitting can be used to
    investigate whether the differences are substantial.  Many things can
    change the output of the fit slightly in ways that would not be considered
    a real difference from a scientific perspective, i.e. numbers off at a
    precision of e-6.  One should look through the plots created by
    check_psp_fitting if there is any question.  Unexpected things, such as
    the order of the parameters fed to the function, can create completely
    different fits.
    """
    # one saved JSON test case per file
    test_data_files = [os.path.join(test_data_dir, f) for f in os.listdir(test_data_dir)]
    for path in sorted(test_data_files):
        # e.g. 'test_psp_fit/1492546902.92_2_6stacked.json': the order of the
        # parameters passed to fit_psp affects this fit
        print('file %s' % path)
        # close the file handle promptly instead of leaking one per test case
        with open(path) as fh:
            test_dict = json.load(fh)
        inputs = test_dict['input']
        expected = test_dict['out']

        avg_trace = neuroanalysis.data.Trace(data=np.array(inputs['data']), dt=inputs['dt'])
        # keyword order is kept exactly as when the reference data were made
        psp_fits = fit_psp(avg_trace,
                           xoffset=(14e-3, -float('inf'), float('inf')),
                           weight=np.array(inputs['weight']),
                           sign=inputs['amp_sign'],
                           stacked=inputs['stacked'],
                           )

        assert expected['best_values'] == psp_fits.best_values, \
            "Best values don't match. Run check_psp_fitting for more information"

        assert expected['best_fit'] == psp_fits.best_fit.tolist(), \
            "Best fit traces don't match. Run check_psp_fitting for more information"

        assert expected['nrmse'] == float(psp_fits.nrmse()), \
            "Nrmse doesn't match. Run check_psp_fitting for more information"
Esempio n. 2
0
def fit_avg_pulse_response(pulse_response_list, latency_window, sign, init_params=None, ui=None):
    """Average a list of pulse responses and fit a PSP model to the average.

    The responses are aligned to the presynaptic spike and baseline-subtracted
    before averaging.  Samples around the expected PSP onset are weighted more
    heavily in the fit.  A placeholder exists for masking crosstalk artifacts
    between nearby electrodes, but no masking is applied yet.

    Parameters
    ----------
    pulse_response_list : list
        PulseResponse instances to be time-aligned, averaged, and fit.
    latency_window : (float, float)
        Beginning and end times of a window over which to search for a
        synaptic response, relative to the spike time.
    sign : int
        +1, -1, or 0 indicating the expected sign of the response
        (see neuroanalysis.fitting.fit_psp).
    init_params : dict or None
        Initial parameter values, passed through to fit_psp unchanged.
    ui : object or None
        Not used by this function.

    Returns
    -------
    fit : lmfit ModelResult or None
        The resulting PSP fit, or None if there was nothing to average.
    average : TSeries or None
        The averaged pulse response data, or None if there was nothing to average.
    """
    first_response = pulse_response_list[0]
    pair = first_response.pair
    clamp_mode = first_response.recording.patch_clamp_recording.clamp_mode

    # spike-aligned, baseline-subtracted postsynaptic recordings
    aligned = PulseResponseList(pulse_response_list).post_tseries(align='spike', bsub=True)
    if len(aligned) == 0:
        return None, None

    average = aligned.mean()

    # Uniform weights, boosted x3 from the start of the latency window until
    # 4 ms past its end (the region where the PSP onset is expected).
    weight = np.ones(len(average))
    boost_start = average.index_at(latency_window[0])
    boost_stop = average.index_at(latency_window[1] + 4e-3)
    weight[boost_start:boost_stop] = 3.0

    # Crosstalk masking for nearby electrodes is not implemented; the electrode
    # ids are still resolved here exactly as before.
    pre_electrode = int(pair.pre_cell.electrode.ext_id)
    post_electrode = int(pair.post_cell.electrode.ext_id)
    if abs(pre_electrode - post_electrode) < 3:
        pass  # TODO: mask out crosstalk artifact for nearby electrodes

    fit = fit_psp(
        average,
        search_window=latency_window,
        clamp_mode=clamp_mode,
        sign=sign,
        baseline_like_psp=True,
        init_params=init_params,
        fit_kws={'weights': weight},
    )
    return fit, average
Esempio n. 3
0
def check_psp_fitting(plotting=True):
    """Refit every saved test case and report how it differs from the saved fit.

    For each JSON file in test_data_dir the averaged trace is refit with
    fit_psp.  The best-fit parameters, the best-fit trace and the NRMSE are
    compared (with exact equality) against the saved results, and any
    mismatch is printed.  When *plotting* is true, a figure is shown per file
    with the data, the current best fit and the weighting.  If anything
    changed, the original best fit is overlaid and the figure is annotated
    'CHANGE'.

    Parameters
    ----------
    plotting : bool
        Whether to make plots of the fitting results (default True).
    """
    if plotting:
        # imported once, and only when plots are requested
        import matplotlib.pylab as mplt

    test_data_files = [os.path.join(test_data_dir, f) for f in os.listdir(test_data_dir)]
    for path in sorted(test_data_files):
        # e.g. 'test_psp_fit/1492546902.92_2_6stacked.json': the order of the
        # parameters passed to fit_psp affects this fit
        print('file %s' % path)
        with open(path) as fh:
            test_dict = json.load(fh)
        inputs = test_dict['input']
        saved = test_dict['out']

        avg_trace = neuroanalysis.data.Trace(data=np.array(inputs['data']), dt=inputs['dt'])
        psp_fits = fit_psp(avg_trace,
                           weight=np.array(inputs['weight']),
                           xoffset=(14e-3, -float('inf'), float('inf')),
                           sign=inputs['amp_sign'],
                           stacked=inputs['stacked'],
                           )
        # compute once; used for comparison, reporting and the plot title
        best_fit = psp_fits.best_fit.tolist()
        nrmse = float(psp_fits.nrmse())

        change_flag = False
        if saved['best_values'] != psp_fits.best_values:
            print('  the best values dont match')
            print('\tsaved %s' % (saved['best_values'],))
            print('\tobtained %s' % (psp_fits.best_values,))
            change_flag = True

        if saved['best_fit'] != best_fit:
            print('  the best fit traces dont match')
            print('\tsaved %s' % (saved['best_fit'],))
            print('\tobtained %s' % (best_fit,))
            change_flag = True

        if saved['nrmse'] != nrmse:
            print('  the nrmse doesnt match')
            print('\tsaved %s' % (saved['nrmse'],))
            print('\tobtained %s' % (nrmse,))
            change_flag = True

        if plotting:
            fig = mplt.figure(figsize=(20, 8))
            ax = fig.add_subplot(1, 1, 1)
            ax2 = ax.twinx()  # weighting is drawn on its own y axis
            ax.plot(avg_trace.time_values, psp_fits.data * 1.e3, 'b', label='data')
            ax.plot(avg_trace.time_values, psp_fits.best_fit * 1.e3, 'g', lw=5, label='current best fit')
            ax2.plot(avg_trace.time_values, inputs['weight'], 'r', label='weighting')
            if change_flag:
                ax.plot(avg_trace.time_values, np.array(saved['best_fit']) * 1.e3, 'k--', lw=5,
                        label='original best fit')
                mplt.annotate('CHANGE', xy=(.5, .5), xycoords='figure fraction', fontsize=40)
            ax.legend()
            mplt.title(path + ', nrmse =' + str(nrmse))
            mplt.show()
Esempio n. 4
0
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    """Fit the average first-pulse response and return its PSP features.

    Parameters
    ----------
    pair : object
        The cell pair being analyzed.  ``pair.synapse_prediction`` provides
        ``ic_fit_xoffset`` (the initial x-offset for the fit; 14 ms is used
        when it is None) and ``synapse_type`` ('ex' or 'in').
    pulse_responses : list
        Pulse responses to be averaged with ``TSeriesList(...).mean()``.
    pulse_response_amps : array-like
        Amplitudes of the individual responses, used for the coefficient of
        variation (std / mean).

    Returns
    -------
    features : dict
        ``ic_fit_amp``, ``ic_fit_latency`` (fit x-offset minus 10 ms),
        ``ic_fit_rise_time``, ``ic_fit_decay_tau``, ``ic_amp_cv`` and
        ``avg_psp`` (baseline-subtracted average data).

    Raises
    ------
    Exception
        If the synapse type is neither 'ex' nor 'in'.
    """
    avg_psp = TSeriesList(pulse_responses).mean()
    dt = avg_psp.dt
    # baseline estimated with float_mode over the first 10 ms, i.e. before the
    # zero-weighted stimulus-artifact window (10-12 ms) set up below
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')

    xoffset = pair.synapse_prediction.ic_fit_xoffset
    if xoffset is None:
        # 14 ms, the same initial x-offset used by the other PSP fits
        # (the former expression 14 * 10e-3 evaluated to 140 ms).
        xoffset = 14e-3

    synapse_type = pair.synapse_prediction.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception(
            'Synapse type is not defined, reconsider fitting this pair %s %d->%d'
            % (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    weight = np.ones(len(avg_psp.data)) * 10.  # set everything to ten initially
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {
        'ic_fit_amp': psp_fits.best_values['amp'],
        'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
        'ic_fit_rise_time': psp_fits.best_values['rise_time'],
        'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
        'ic_amp_cv': amp_cv,
        'avg_psp': avg_psp_bsub.data,
    }
    # TODO: add 'ic_fit_NRMSE': psp_fits.nrmse() -- nrmse not returned from psp_fits?

    return features
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    """Fit the average first-pulse response and return its PSP features.

    Parameters
    ----------
    pair : object
        The cell pair being analyzed.  ``pair.connection_strength`` provides
        ``ic_fit_xoffset`` (the initial x-offset for the fit; 14 ms is used
        when it is None) and ``synapse_type`` ('ex' or 'in').
    pulse_responses : list
        Pulse responses to be averaged with ``TraceList(...).mean()``.
    pulse_response_amps : array-like
        Amplitudes of the individual responses, used for the coefficient of
        variation (std / mean).

    Returns
    -------
    features : dict
        ``ic_fit_amp``, ``ic_fit_latency`` (fit x-offset minus 10 ms),
        ``ic_fit_rise_time``, ``ic_fit_decay_tau``, ``ic_amp_cv`` and
        ``avg_psp`` (baseline-subtracted average data).

    Raises
    ------
    Exception
        If the synapse type is neither 'ex' nor 'in'.
    """
    avg_psp = TraceList(pulse_responses).mean()
    dt = avg_psp.dt
    # baseline estimated with float_mode over the first 10 ms, i.e. before the
    # zero-weighted stimulus-artifact window (10-12 ms) set up below
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')

    xoffset = pair.connection_strength.ic_fit_xoffset
    if xoffset is None:
        # 14 ms, the same initial x-offset used by the other PSP fits
        # (the former expression 14*10e-3 evaluated to 140 ms).
        xoffset = 14e-3

    synapse_type = pair.connection_strength.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception('Synapse type is not defined, reconsider fitting this pair %s %d->%d' %
                        (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    weight = np.ones(len(avg_psp.data)) * 10.  # set everything to ten initially
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {'ic_fit_amp': psp_fits.best_values['amp'],
                'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
                'ic_fit_rise_time': psp_fits.best_values['rise_time'],
                'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
                'ic_amp_cv': amp_cv,
                'avg_psp': avg_psp_bsub.data}
    # TODO: add 'ic_fit_NRMSE': psp_fits.nrmse() -- nrmse not returned from psp_fits?

    return features
Esempio n. 6
0
        #             #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
        #             all_amps = fail_rate(response_subset, '+', peak_t)
        #             cv = np.std(all_amps)/np.mean(all_amps)

        # weight parts of the trace during fitting
        dt = avg_trace.dt
        weight = np.ones(len(
            avg_trace.data)) * 10.  #set everything to ten initially
        weight[int(10e-3 / dt):int(12e-3 /
                                   dt)] = 0.  #area around stim artifact
        weight[int(12e-3 / dt):int(19e-3 /
                                   dt)] = 30.  #area around steep PSP rise

        psp_fits = fit_psp(
            avg_trace,
            xoffset=(14e-3, -float('inf'), float('inf')),
            sign=amp_sign,
            #                                           amp=avg_amp,
            weight=weight)
        plt.clear()
        plt.plot(avg_trace.time_values,
                 avg_trace.data,
                 title=str(
                     [psp_fits.best_values['xoffset'], expt.uid, pre, post]))
        plt.plot(avg_trace.time_values, psp_fits.eval(), pen='g')
        # avg_trace.t0 = -(psp_fits.best_values['xoffset'] - 10e-3)
        # distance = expt.cells[pre].distance(expt.cells[post])
        # grand_response[conn_type[0]]['CV'].append(cv)
        # latency = psp_fits.best_values['xoffset'] - 10e-3
        # rise = psp_fits.best_values['rise_time']
        # decay = psp_fits.best_values['decay_tau']
        # nrmse = psp_fits.nrmse()
        avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(pulse_traces)
        #             if amp_sign is '-':
        #                 continue
        #             #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
        #             all_amps = fail_rate(response_subset, '+', peak_t)
        #             cv = np.std(all_amps)/np.mean(all_amps)

                    # weight parts of the trace during fitting
        dt = avg_trace.dt
        weight = np.ones(len(avg_trace.data))*10.  #set everything to ten initially
        weight[int(10e-3/dt):int(12e-3/dt)] = 0.   #area around stim artifact
        weight[int(12e-3/dt):int(19e-3/dt)] = 30.  #area around steep PSP rise

        psp_fits = fit_psp(avg_trace,
                           xoffset=(14e-3, -float('inf'), float('inf')),
                           sign=amp_sign,
                           #                                           amp=avg_amp,
                           weight=weight)
        plt.clear()
        plt.plot(avg_trace.time_values, avg_trace.data, title=str([psp_fits.best_values['xoffset'], expt.uid, pre, post]))
        plt.plot(avg_trace.time_values, psp_fits.eval(), pen='g')
                    # avg_trace.t0 = -(psp_fits.best_values['xoffset'] - 10e-3)
                    # distance = expt.cells[pre].distance(expt.cells[post])
                    # grand_response[conn_type[0]]['CV'].append(cv)
                    # latency = psp_fits.best_values['xoffset'] - 10e-3
                    # rise = psp_fits.best_values['rise_time']
                    # decay = psp_fits.best_values['decay_tau']
                    # nrmse = psp_fits.nrmse()
                    # if nrmse < fit_qc['nrmse']:
                    #     grand_response[conn_type[0]]['latency'].append(psp_fits.best_values['xoffset'] - 10e-3)
                    #     max_x = np.argwhere(psp_fits.eval() == max(psp_fits.eval()))[0, 0]
Esempio n. 8
0
def fit_trace(waveform,
              excitation,
              clamp_mode='ic',
              weight=None,
              latency=None,
              latency_jitter=None):
    """Fit a PSP/PSC model to a spike-aligned waveform with fit_psp.

    Input
    -----
    waveform: TSeries Object
        contains data to be fit
    excitation: str
        'ex' or 'in' specifying excitation of synapse, or 'any' to leave the
        sign of the fit unconstrained
    clamp_mode: string
        'vc' denotes voltage clamp
        'ic' denotes current clamp
    weight: numpy array or None
        Relative weighting of different sections of the voltage array
        for fitting.  If specified it must be the same length as the
        waveform; defaults to uniform weights.
    latency: float or None
        Amount of time that has passed in reference to the time of
        the pre-synaptic spike.  This value is shifted by time_before_spike
        to be relative to the start of the data waveform being fit.  If
        given without *latency_jitter*, the x-offset is held fixed.
        NOTE: a latency of exactly 0 is falsy and is treated like None.
    latency_jitter: None or float
        Amount of jitter to allow before and after the latency value. If
        *latency* is None, this value must be None.  Must be at least .1e-3.

    Note there is a 'time_before_spike' variable in this code that is
    not passed in and therefore must be global.  It specifies the
    amount of data before the spike that is being considered here and
    in the rest of the code.  Note that the value is positive,
    i.e. if we start looking at the trace 10 ms before the spike we use
    10e-3 not -10e-3.

    Returns
    -------
    fit: lmfit.model.ModelResult
        fit of the average psp waveform

    Raises
    ------
    Exception
        If the excitation / clamp_mode combination is not recognized, if
        latency_jitter is given without latency, or if latency_jitter is
        smaller than .1e-3.
    """
    # weighting
    if weight is None:
        weight = np.ones(len(waveform.data))  # set everything to ones

    # set fitting sign to positive or negative based on excitation and clamp state
    if (excitation == 'in') and (clamp_mode == 'ic'):
        sign = '-'
    elif (excitation == 'in') and (clamp_mode == 'vc'):
        sign = '+'
    elif (excitation == 'ex') and (clamp_mode == 'ic'):
        sign = '+'
    elif (excitation == 'ex') and (clamp_mode == 'vc'):
        sign = '-'
    elif excitation == 'any':
        sign = 'any'
    else:
        raise Exception('synaptic sign not defined')

    # xoffset is (initial value(s), lower bound, upper bound) or (value, 'fixed')
    if latency is None and latency_jitter:
        raise Exception('latency_jitter cannot be specified if latency is not')
    if latency_jitter:
        if latency_jitter < .1e-3:
            raise Exception(
                'specified latency jitter less than .1e-3 may have implications for initial conditions'
            )
        if latency < latency_jitter:
            # don't let the lower bound precede the spike
            lower_bound = time_before_spike
        else:
            lower_bound = latency + time_before_spike - latency_jitter
        xoffset = ([
            lower_bound, latency + time_before_spike,
            latency + time_before_spike + latency_jitter - .1e-3
        ], lower_bound, latency + time_before_spike + latency_jitter)
    elif latency:
        xoffset = (latency + time_before_spike, 'fixed')
    else:
        # since these are spike aligned the psp should not happen before the spike that happens at pre_pad by definition
        xoffset = ([time_before_spike + 1e-3, time_before_spike + 4e-3],
                   time_before_spike, time_before_spike + 5e-3)

    return fit_psp(waveform,
                   clamp_mode=clamp_mode,
                   xoffset=xoffset,
                   sign=sign,
                   weight=weight)
def fit_trace(waveform, excitation, clamp_mode='ic', weight=None, latency=None, latency_jitter=None):
    """Fit a PSP/PSC model to a spike-aligned waveform with fit_psp.

    Input
    -----
    waveform: Trace Object
        contains data to be fit
    excitation: str
        'ex' or 'in' specifying excitation of synapse, or 'any' to leave the
        sign of the fit unconstrained
    clamp_mode: string
        'vc' denotes voltage clamp
        'ic' denotes current clamp
    weight: numpy array or None
        Relative weighting of different sections of the voltage array
        for fitting.  If specified it must be the same length as the
        waveform; defaults to uniform weights.
    latency: float or None
        Amount of time that has passed in reference to the time of
        the pre-synaptic spike.  This value is shifted by time_before_spike
        to be relative to the start of the data waveform being fit.  If
        given without *latency_jitter*, the x-offset is held fixed.
        NOTE: a latency of exactly 0 is falsy and is treated like None.
    latency_jitter: None or float
        Amount of jitter to allow before and after the latency value. If
        *latency* is None, this value must be None.  Must be at least .1e-3.

    Note there is a 'time_before_spike' variable in this code that is
    not passed in and therefore must be global.  It specifies the
    amount of data before the spike that is being considered here and
    in the rest of the code.  Note that the value is positive,
    i.e. if we start looking at the trace 10 ms before the spike we use
    10e-3 not -10e-3.

    Returns
    -------
    fit: lmfit.model.ModelResult
        fit of the average psp waveform

    Raises
    ------
    Exception
        If the excitation / clamp_mode combination is not recognized, if
        latency_jitter is given without latency, or if latency_jitter is
        smaller than .1e-3.
    """
    # weighting
    if weight is None:
        weight = np.ones(len(waveform.data))  # set everything to ones

    # set fitting sign to positive or negative based on excitation and clamp state
    if (excitation == 'in') and (clamp_mode == 'ic'):
        sign = '-'
    elif (excitation == 'in') and (clamp_mode == 'vc'):
        sign = '+'
    elif (excitation == 'ex') and (clamp_mode == 'ic'):
        sign = '+'
    elif (excitation == 'ex') and (clamp_mode == 'vc'):
        sign = '-'
    elif excitation == 'any':
        sign = 'any'
    else:
        raise Exception('synaptic sign not defined')

    # xoffset is (initial value(s), lower bound, upper bound) or (value, 'fixed')
    if latency is None and latency_jitter:
        raise Exception('latency_jitter cannot be specified if latency is not')
    if latency_jitter:
        if latency_jitter < .1e-3:
            raise Exception('specified latency jitter less than .1e-3 may have implications for initial conditions')
        if latency < latency_jitter:
            # don't let the lower bound precede the spike
            lower_bound = time_before_spike
        else:
            lower_bound = latency + time_before_spike - latency_jitter
        xoffset = ([lower_bound, latency + time_before_spike, latency + time_before_spike + latency_jitter - .1e-3],
                   lower_bound, latency + time_before_spike + latency_jitter)
    elif latency:
        xoffset = (latency + time_before_spike, 'fixed')
    else:
        # since these are spike aligned the psp should not happen before the spike that happens at pre_pad by definition
        xoffset = ([time_before_spike + 1e-3, time_before_spike + 4e-3], time_before_spike, time_before_spike + 5e-3)

    return fit_psp(waveform,
                   clamp_mode=clamp_mode,
                   xoffset=xoffset,
                   sign=sign,
                   weight=weight)
Esempio n. 10
0
def measure_response(rec, baseline_rec):
    """Fit a single pulse response (and optionally a baseline) with fit_psp.

    The synapse's known latency and kinetics seed the fit.  The rise time and
    decay tau come from the PSP fields of *rec* in current clamp ('ic') and
    from the PSC fields otherwise.  The onset is searched within +/-100 us of
    ``rec.latency``.  When *baseline_rec* is given, its data are fit with
    exactly the same settings, for noise measurement.

    Returns
    -------
    (response_fit, baseline_fit)
        baseline_fit is None when *baseline_rec* is None.  Both are None when
        the spike time, latency, rise time or decay tau is missing or not finite.
    """
    if rec.clamp_mode == 'ic':
        rise_time, decay_tau = rec.psp_rise_time, rec.psp_decay_tau
    else:
        rise_time, decay_tau = rec.psc_rise_time, rec.psc_decay_tau

    # every seed value must be present and finite
    seeds = (rec.spike_time, rec.latency, rise_time, decay_tau)
    if any(v is None or rec.latency is None or not np.isfinite(v) for v in seeds):
        return None, None

    data = TSeries(rec.data,
                   t0=rec.rec_start - rec.spike_time,
                   sample_rate=db.default_sample_rate)

    # Sign constraint: excitatory -> positive; inhibitory -> negative only when
    # the baseline potential is above -60 mV, otherwise unconstrained.
    # Voltage clamp flips the sign.
    if rec.synapse_type == 'ex':
        sign = 1
    elif rec.synapse_type == 'in' and rec.baseline_potential > -60e-3:
        sign = -1
    else:
        sign = 0
    if rec.clamp_mode == 'vc':
        sign = -sign

    def _fit(tseries):
        # identical settings for response and baseline regions
        return fit_psp(
            tseries,
            search_window=rec.latency + np.array([-100e-6, 100e-6]),
            clamp_mode=rec.clamp_mode,
            sign=sign,
            baseline_like_psp=True,
            init_params={'rise_time': rise_time, 'decay_tau': decay_tau},
            refine=False,
        )

    response_fit = _fit(data)
    if baseline_rec is None:
        return response_fit, None

    baseline = TSeries(baseline_rec.data,
                       t0=data.t0,
                       sample_rate=db.default_sample_rate)
    return response_fit, _fit(baseline)
def measure_response(pr):
    """Curve fit a single pulse response to measure its amplitude / kinetics.

    Uses the known latency and kinetics of the synapse to seed the fit.  The
    rise time and decay tau (PSP values in current clamp, PSC values
    otherwise) are held fixed, and the onset is searched within +/-100 us of
    the synapse latency.  If a baseline region is available for the same pulse
    response, it is fit with exactly the same settings for noise measurement,
    so that the two fits are directly comparable.

    Parameters
    ----------
    pr : PulseResponse

    Returns
    -------
    response_fit, baseline_fit
        fit_psp results for the response and baseline regions.  baseline_fit
        is None when no baseline tseries is available.  Both are None when the
        spike time, latency, rise time or decay tau is missing or not finite.
    """
    syn = pr.pair.synapse
    pcr = pr.recording.patch_clamp_recording
    if pcr.clamp_mode == 'ic':
        rise_time = syn.psp_rise_time
        decay_tau = syn.psp_decay_tau
    else:
        rise_time = syn.psc_rise_time
        decay_tau = syn.psc_decay_tau

    # make sure all parameters are available
    for v in [pr.stim_pulse.first_spike_time, syn.latency, rise_time, decay_tau]:
        if v is None or syn.latency is None or not np.isfinite(v):
            return None, None

    data = pr.get_tseries('post', align_to='spike')

    # decide whether/how to constrain the sign of the fit
    if syn.synapse_type == 'ex':
        sign = 1
    elif syn.synapse_type == 'in':
        if pcr.baseline_potential > -60e-3:
            sign = -1
        else:
            sign = 0
    else:
        sign = 0
    if pcr.clamp_mode == 'vc':
        sign = -sign

    def _fit(tseries):
        # Response and baseline are fit identically (including the fixed
        # kinetics); previously the baseline fit left rise_time/decay_tau free,
        # so its noise estimate was not comparable to the response fit.
        return fit_psp(
            tseries,
            search_window=syn.latency + np.array([-100e-6, 100e-6]),
            clamp_mode=pcr.clamp_mode,
            sign=sign,
            baseline_like_psp=True,
            init_params={
                'rise_time': rise_time,
                'decay_tau': decay_tau
            },
            refine=False,
            decay_tau_bounds=('fixed', ),
            rise_time_bounds=('fixed', ),
        )

    # fit response region
    response_fit = _fit(data)

    # fit baseline region
    baseline = pr.get_tseries('baseline', align_to='spike')
    baseline_fit = None if baseline is None else _fit(baseline)

    return response_fit, baseline_fit
Esempio n. 12
0
def save_fit_psp_test_set():
    """NOTE: THIS CODE DOES NOT WORK, BUT IS KEPT FOR DOCUMENTATION PURPOSES SO
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Create a test set of data for testing the fit_psp function.  Uses Steph's
    original first_puls_feature.py code to filter out error-causing data.

    For each selected connection, the average pulse response is fit twice
    with identical fit_psp inputs.  The first fit uses the averaged trace as
    returned by get_amplitude.  The second uses a Trace rebuilt from the
    JSON-serializable copy of the data in save_dict.  A message is printed
    when the two NRMSE values differ.

    Example run statement:
        python save_fit_psp_test_set.py --organism mouse --connection ee

    The original instructions say to comment in the code that does the saving
    at the bottom.  NOTE(review): no such code is present in this version;
    save_dict is built for each connection but never written to disk.
    """
    
    
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os
    
    # interactive Qt application, debug console and white-background plot theme
    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')
    
    parser = argparse.ArgumentParser(description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))
    
    all_expts = cached_experiments()
    # manifest, fit_qc and several plots below are leftovers from the analysis
    # script this was derived from; they are not used to build save_dict
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [],'rise':[], 'rise2080': [], 'rise1090': [], 'rise1080': [],
                'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}
    
    # per-organism selection criteria and the list of connection types to process
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            # NOTE(review): '==' compares instead of assigning, so connection_types
            # is never set for 'ie' and its later use raises UnboundLocalError.
            connection_types == ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # NOTE(review): uses c_type[0] (the pre type) for the post type;
                # c_type[1] was presumably intended, matching the check above.
                post_type = (None, c_type[0])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    
    # pyqtgraph windows for QC and plotting
    plt = pg.plot()
    
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5,1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    # NOTE(review): connection_types[c] requires an indexable sequence; the
    # dict.keys() results above are only indexable on Python 2.
    for c in range(len(connection_types)):
        # each connection type is ((pre_layer, pre_cre), (post_layer, post_cre))
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                # skip explicitly excluded connections
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    # skip connections whose artifact exceeds the organism's threshold
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
                    # require a minimum number of sweeps before and after QC
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
    #                        if amp_sign is '-':
    #                            continue
    #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
    #                        all_amps = fail_rate(response_subset, '+', peak_t)
    #                        cv = np.std(all_amps)/np.mean(all_amps)
    #                        
    #                        # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data))*10.  #set everything to ten initially
                            weight[int(10e-3/dt):int(12e-3/dt)] = 0.   #area around stim artifact
                            weight[int(12e-3/dt):int(19e-3/dt)] = 30.  #area around steep PSP rise 
                            
                            # check if the test data dir is there and if not create it
                            test_data_dir='test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)
                                
                            # JSON-serializable fit inputs for this connection
                            # (never written to disk in this version)
                            save_dict={}
                            save_dict['input']={'data': avg_trace.data.tolist(),
                                                'dtype': str(avg_trace.data.dtype),
                                                'dt': float(avg_trace.dt),
                                                'amp_sign': amp_sign,
                                                'yoffset': 0, 
                                                'xoffset': 14e-3, 
                                                'avg_amp': float(avg_amp),
                                                'method': 'leastsq', 
                                                'stacked': False, 
                                                'rise_time_mult_factor': 10., 
                                                'weight': weight.tolist()} 
                            
                            # need to remake trace because different output is created
                            avg_trace_simple=Trace(data=np.array(save_dict['input']['data']), dt=save_dict['input']['dt']) # create Trace object
                            
                            # fit of the trace as returned by get_amplitude
                            psp_fits_original = fit_psp(avg_trace, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
    
                            # same fit on the trace rebuilt from the serialized data;
                            # the NRMSE comparison below checks the round trip
                            psp_fits_simple = fit_psp(avg_trace_simple, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
                            print expt.uid, pre, post    
                            if psp_fits_original.nrmse()!=psp_fits_simple.nrmse():     
                                print '  the nrmse values dont match'
                                print '\toriginal', psp_fits_original.nrmse()
                                print '\tsimple', psp_fits_simple.nrmse()
def fit_trace(waveform,
              excitation,
              clamp_mode='ic',
              weight=None,
              latency=None,
              latency_jitter=None):
    """Fit a PSP model to a spike-aligned waveform via fit_psp.

    Input
    -----
    waveform: Trace Object
        contains data to be fit (must expose ``data``; ``time_values`` is
        only used by the disabled plotting code)
    excitation: str
        'ex' or 'in' specifying excitation of synapse, or 'any' to let the
        fit use either sign
    clamp_mode: string
        'vc' denotes voltage clamp
        'ic' denotes current clamp
    weight: numpy array or None
        Relative weighting of different sections of the voltage array
        for fitting.  If specified it must be the same length as the
        waveform.  If None, every sample gets a weight of one.
    latency: float or None
        Amount of time that has passed in reference to the time of
        the pre-synaptic spike.  It is converted to a time relative to the
        start of the waveform by adding ``time_before_spike``.  A latency
        of 0.0 is a real value and fixes xoffset at the spike time.
    latency_jitter: None or float
        Amount of jitter to allow before and after the latency value. If
        *latency* is None, this value must be None.  When given it must be
        at least .1e-3.

    Note there is a 'time_before_spike' variable in this code that is
    not passed in and therefore must be global.  It specifies the
    amount of data before the spike that is being considered here and in
    the rest of the code.  Note that the value is positive,
    i.e. if we start looking at the trace 10 ms before the spike we use
    10e-3 not -10e-3.

    Returns
    -------
    fit: lmfit.model.ModelResult
        fit of the average psp waveform as returned by fit_psp

    Raises
    ------
    ValueError
        if the synaptic sign cannot be determined from *excitation* and
        *clamp_mode*, if *latency_jitter* is given without *latency*, or
        if *latency_jitter* is smaller than .1e-3.
    """
    # weighting
    if weight is None:
        weight = np.ones(len(waveform.data))  # set everything to ones

    # The expected deflection sign depends on both the synapse type and on
    # what is being recorded (voltage in 'ic', current in 'vc').
    sign_table = {
        ('in', 'ic'): '-',
        ('in', 'vc'): '+',
        ('ex', 'ic'): '+',
        ('ex', 'vc'): '-',
    }
    if (excitation, clamp_mode) in sign_table:
        sign = sign_table[(excitation, clamp_mode)]
    elif excitation == 'any':
        sign = 'any'
    else:
        raise ValueError('synaptic sign not defined')

    # Validate latency arguments before touching the global time_before_spike.
    if latency is None and latency_jitter:
        raise ValueError('latency_jitter cannot be specified if latency is not')
    if latency_jitter and latency_jitter < .1e-3:
        raise ValueError(
            'specified latency jitter less than .1e-3 may have implications for initial conditions'
        )

    if latency_jitter:
        spike_latency = latency + time_before_spike
        # Do not let the search window start before the spike itself.
        if latency < latency_jitter:
            lower_bound = time_before_spike
        else:
            lower_bound = spike_latency - latency_jitter
        upper_bound = spike_latency + latency_jitter
        # initial guesses (latest one kept just inside the upper bound), min, max
        xoffset = ([lower_bound, spike_latency, upper_bound - .1e-3],
                   lower_bound, upper_bound)
    elif latency is not None:
        xoffset = (latency + time_before_spike, 'fixed')
    else:
        # since these are spike aligned the psp should not happen before the
        # spike that happens at time_before_spike by definition
        xoffset = ([time_before_spike + 1e-3, time_before_spike + 4e-3],
                   time_before_spike, time_before_spike + 5e-3)

    fit = fit_psp(waveform,
                  clamp_mode=clamp_mode,
                  xoffset=xoffset,
                  sign=sign,
                  weight=weight)

    if False:  # debugging aid: set to True to inspect the fit (needs plt in scope)
        if clamp_mode == 'ic':
            scale_factor, unit, ylabel = 1.e3, 'mV', 'voltage (mV)'
        else:
            scale_factor, unit, ylabel = 1.e12, 'pA', 'current (pA)'
        plt.figure(figsize=(14, 10))
        ax1 = plt.subplot(1, 1, 1)
        ln1 = ax1.plot(waveform.time_values * 1.e3,
                       waveform.data * scale_factor,
                       'b',
                       label='data')
        fit_label = ('nrmse=%f \namp (%s)=%f \nlatency (ms)=%f \nrise time (ms)=%f \ndecay tau=%f' %
                     (fit.nrmse(),
                      unit,
                      fit.best_values['amp'] * scale_factor,
                      (fit.best_values['xoffset'] - time_before_spike) * 1e3,
                      fit.best_values['rise_time'] * 1e3,
                      fit.best_values['decay_tau']))
        ln2 = ax1.plot(waveform.time_values * 1.e3,
                       fit.best_fit * scale_factor,
                       'r',
                       label=fit_label)
        ax2 = ax1.twinx()
        ln3 = ax2.plot(waveform.time_values * 1.e3,
                       weight,
                       'k',
                       label='weight')
        ax1.set_ylabel(ylabel)
        ax2.set_ylabel('weight')
        ax1.set_xlabel('time (ms): spike happens at %g ms' % (time_before_spike * 1e3))

        lines_plot = ln1 + ln2 + ln3
        label_plot = [l.get_label() for l in lines_plot]
        ax1.legend(lines_plot, label_plot)
        plt.show()

    return fit