def responses(expt, pre, post, thresh, filter=None):
    """Return the averaged response of one connection, memoised in a disk cache.

    Results are stored in the pickle cache ``result_cache_file`` (read with
    ``load_cache``) under the key ``(expt.nwb_file, pre, post)``. The whole
    cache is rewritten to disk after every newly computed connection.

    Parameters
    ----------
    expt
        Experiment object; ``expt.nwb_file`` is part of the cache key.
    pre, post
        Pre- and postsynaptic cell identifiers.
    thresh
        Maximum accepted artifact value. Connections whose artifact exceeds
        it are rejected.
    filter
        Optional ``(freq_range, holding_range)`` pair. When given, pulse
        responses are selected with ``response_filter`` and averaged with
        ``get_amplitude``. Otherwise ``DynamicsAnalyzer`` estimates the
        amplitude and ``cross_talk()`` supplies the artifact.

    Returns
    -------
    tuple
        ``(avg_amp, avg_trace, n_sweeps)``. ``avg_amp`` and ``avg_trace`` are
        ``None`` when the connection is rejected (no sweeps, or artifact above
        ``thresh``). ``n_sweeps`` is ``None`` only for rejected entries that
        were cached before the sweep count was recorded.
    """
    key = (expt.nwb_file, pre, post)
    result_cache = load_cache(result_cache_file)
    if key in result_cache:
        res = result_cache[key]
        if 'avg_amp' not in res:
            # Rejected connection. Older entries are empty dicts without a
            # sweep count, hence .get().
            return None, None, res.get('n_sweeps')
        avg_trace = Trace(data=res['data'], dt=res['dt'])
        return res['avg_amp'], avg_trace, res['n_sweeps']

    if filter is not None:
        pulse_responses, artifact = get_response(expt, pre, post, type='pulse')
        response_subset = response_filter(pulse_responses, freq_range=filter[0], holding_range=filter[1], pulse=True)
        n_sweeps = len(response_subset)
        if n_sweeps > 0:
            avg_trace, avg_amp, _, _ = get_amplitude(response_subset)
    else:
        analyzer = DynamicsAnalyzer(expt, pre, post, align_to='spike')
        avg_amp, _, avg_trace, _, n_sweeps = analyzer.estimate_amplitude(plot=False)
        artifact = analyzer.cross_talk()

    # `n_sweeps == 0` short-circuits, so avg_amp/avg_trace are never read
    # while unbound on the filter path.
    if n_sweeps == 0 or artifact > thresh:
        # Record the sweep count so a later cache hit returns the same tuple
        # as this fresh computation.
        result_cache[key] = {'n_sweeps': n_sweeps}
        ret = None, None, n_sweeps
    else:
        result_cache[key] = {'avg_amp': avg_amp, 'data': avg_trace.data, 'dt': avg_trace.dt, 'n_sweeps': n_sweeps}
        ret = avg_amp, avg_trace, n_sweeps

    # Serialize before opening the file, so a pickling error cannot truncate
    # the existing cache. The context manager guarantees the handle is closed.
    data = pickle.dumps(result_cache)
    with open(result_cache_file, 'wb') as cache_file:
        cache_file.write(data)
    print(key)
    return ret
def train_response_plot(expt_list, name=None, summary_plots=None, color=None):
    """Plot 50 Hz train responses per connection and their grand averages.

    For each connection in ``expt_list`` whose pre/post cre types equal the
    module-level ``cre_type`` pair:
      - train responses are fetched with ``get_response``;
      - ``response_filter`` splits them into two groups, named induction and
        recovery here;
      - connections with an artifact above 0.03e-3, or with 5 or fewer sweeps
        in the first group, are skipped;
      - per-connection averages are plotted on a new plot.

    The grand means are then plotted (green, width 3). The induction grand
    mean is also deconvolved with ``exp_deconvolve`` (tau = 15e-3) and
    filtered with ``bessel_filter`` (lp = 1000). Finally everything is drawn
    through ``summary_plot_train``.

    Parameters
    ----------
    expt_list : iterable
        Experiments to scan.
    name : optional
        Label for the printed connection count and the legend entries.
    summary_plots : sequence of two plots, optional
        Targets for ``summary_plot_train``. Index 0 receives the induction and
        recovery grand means; index 1 receives the deconvolved induction grand
        mean. Defaults to ``[None, None]``.
    color : optional
        Color passed to ``summary_plot_train``.

    Returns
    -------
    tuple or None
        ``(train_plots, train_plots2, train_amps)`` when at least one
        connection qualified, otherwise ``None``.
    """
    if summary_plots is None:
        # Fresh list per call instead of a shared mutable default argument.
        summary_plots = [None, None]
    grand_train = [[], []]  # [induction averages, recovery averages]
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3
    lp = 1000
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                print('Processing experiment: %s' % (expt.nwb_file))

                train_responses, artifact = get_response(expt, pre, post, analysis_type='train')
                if artifact > 0.03e-3:
                    continue

                train_filter = response_filter(train_responses['responses'], freq_range=[50, 50], train=0, delta_t=250)
                # NOTE(review): reassigned for every matching connection, so
                # train_amp() below receives the offsets of the last connection
                # processed, even one later rejected by the sweep-count check.
                pulse_offsets = response_filter(train_responses['pulse_offsets'], freq_range=[50, 50], train=0, delta_t=250)

                if len(train_filter[0]) > 5:
                    ind_avg = TraceList(train_filter[0]).mean()
                    rec_avg = TraceList(train_filter[1]).mean()
                    rec_avg.t0 = 0.3
                    grand_train[0].append(ind_avg)
                    grand_train[1].append(rec_avg)
                    train_plots.plot(ind_avg.time_values, ind_avg.data)
                    train_plots.plot(rec_avg.time_values, rec_avg.data)
                    app.processEvents()
    if len(grand_train[0]) != 0:
        # %-formatting so the default name=None does not raise TypeError.
        print('%s n = %d' % (name, len(grand_train[0])))
        ind_grand_mean = TraceList(grand_train[0]).mean()
        rec_grand_mean = TraceList(grand_train[1]).mean()
        ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
        train_plots.addLegend()
        train_plots.plot(ind_grand_mean.time_values, ind_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        train_plots.plot(rec_grand_mean.time_values, rec_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets, '+')
        if ind_grand_mean is not None:
            # NOTE(review): `legend` is not defined in this function; it must
            # be a module-level name.
            train_plots = summary_plot_train(ind_grand_mean, plot=summary_plots[0], color=color,
                                             name=(legend + ' 50 Hz induction'))
            train_plots = summary_plot_train(rec_grand_mean, plot=summary_plots[0], color=color)
            train_plots2 = summary_plot_train(ind_grand_mean_dec, plot=summary_plots[1], color=color,
                                              name=(legend + ' 50 Hz induction'))
            return train_plots, train_plots2, train_amps
    else:
        print("No Traces")
        return None
def save_fit_psp_test_set():
    """Create a test set of data for testing the ``fit_psp`` function.

    NOTE: THIS CODE DOES NOT WORK AS-IS. It is kept for documentation purposes
    so that we can trace back how the test data was created if needed.

    Uses Steph's original first_puls_feature.py code to filter out
    error-causing data. For each qualifying connection it does the following:
      - averages the first-pulse responses;
      - builds the ``fit_psp`` inputs in JSON-friendly form (``save_dict``);
      - fits both the original averaged Trace and a Trace rebuilt from those
        saved inputs;
      - prints the connection, and reports when the two NRMSE values differ.

    Example run statement::

        python save_fit_psp_test_set.py --organism mouse --connection ee

    To actually write the test files, comment in (enable) the saving code at
    the bottom. NOTE(review): no saving code is visible in this copy; it
    appears to end after the NRMSE comparison.
    """
    
    
    # Function-local imports, presumably because this was originally run as
    # a standalone script (see the example run statement above).
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os
    
    # Qt application and pyqtgraph debug console; white background, black foreground.
    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')
    
    # Command-line options: --organism (mouse or human) and --connection
    # (a group such as 'ee', or a cre-type pair such as 'sim1-sim1').
    parser = argparse.ArgumentParser(description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))
    
    all_expts = cached_experiments()
    # NOTE(review): manifest and fit_qc are not used in the code visible here.
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [],'rise':[], 'rise2080': [], 'rise1090': [], 'rise1080': [],
                'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}
    
    # Choose experiment filters and the connection types to process. Each
    # connection type is ((pre_layer, pre_cre), (post_layer, post_cre)), as
    # unpacked in the main loop below.
    # NOTE(review): if --organism is neither 'mouse' nor 'human', or the
    # mouse --connection value matches no branch, connection_types is never
    # assigned and the code below raises NameError.
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3  # minimum number of sweeps, before and after pulse_qc
        threshold = 0.03e-3  # connections with artifact above this are skipped
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            # NOTE(review): `==` is a comparison, not an assignment, so
            # connection_types is left unassigned for 'ie'; `=` was intended.
            connection_types == ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # NOTE(review): uses the *pre* cre type (c_type[0]); c_type[1]
                # was presumably intended.
                post_type = (None, c_type[0])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None  # None disables the artifact check below
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    
    # NOTE(review): plt, scale_offset, scale_anchor, feature_plot,
    # feature2_plot, feature3_plot and amp_plot are created but not otherwise
    # used in the code visible here.
    plt = pg.plot()
    
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]  # holding_range passed to response_filter
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5,1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    # NOTE(review): indexing connection_types[c] requires a sequence. On
    # Python 2 (which the print statements below imply) dict.keys() returns a
    # list; on Python 3 it returns a non-indexable view.
    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                # Skip connections explicitly excluded in rep_connections.no_include.
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                            # The sign filter and CV computation of the original
                            # script are disabled (commented out) in this copy.
    #                        if amp_sign is '-':
    #                            continue
    #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
    #                        all_amps = fail_rate(response_subset, '+', peak_t)
    #                        cv = np.std(all_amps)/np.mean(all_amps)
    #                        
    #                        # weight parts of the trace during fitting
                            # Per-sample fit weights. The windows are in seconds
                            # converted to sample indices via dt; they presumably
                            # assume the stimulus lands near 10 ms -- TODO confirm.
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data))*10.  #set everything to ten initially
                            weight[int(10e-3/dt):int(12e-3/dt)] = 0.   #area around stim artifact
                            weight[int(12e-3/dt):int(19e-3/dt)] = 30.  #area around steep PSP rise 
                            
                            # check if the test data dir is there and if not create it
                            test_data_dir='test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)
                                
                            # fit_psp inputs in JSON-serializable form (lists,
                            # str, float) so they can be written out as test data.
                            save_dict={}
                            save_dict['input']={'data': avg_trace.data.tolist(),
                                                'dtype': str(avg_trace.data.dtype),
                                                'dt': float(avg_trace.dt),
                                                'amp_sign': amp_sign,
                                                'yoffset': 0, 
                                                'xoffset': 14e-3, 
                                                'avg_amp': float(avg_amp),
                                                'method': 'leastsq', 
                                                'stacked': False, 
                                                'rise_time_mult_factor': 10., 
                                                'weight': weight.tolist()} 
                            
                            # need to remake trace because different output is created
                            avg_trace_simple=Trace(data=np.array(save_dict['input']['data']), dt=save_dict['input']['dt']) # create Trace object
                            
                            # Fit the original and the rebuilt trace with identical
                            # arguments, presumably to check that the saved inputs
                            # reproduce the original fit.
                            psp_fits_original = fit_psp(avg_trace, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
    
                            psp_fits_simple = fit_psp(avg_trace_simple, 
                                               sign=save_dict['input']['amp_sign'], 
                                               yoffset=save_dict['input']['yoffset'], 
                                               xoffset=save_dict['input']['xoffset'], 
                                               amp=save_dict['input']['avg_amp'],
                                               method=save_dict['input']['method'], 
                                               stacked=save_dict['input']['stacked'], 
                                               rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'], 
                                               fit_kws={'weights': save_dict['input']['weight']})  
                            # Python 2 print statements: report the connection and
                            # any mismatch between the two fits' NRMSE (exact comparison).
                            print expt.uid, pre, post    
                            if psp_fits_original.nrmse()!=psp_fits_simple.nrmse():     
                                print '  the nrmse values dont match'
                                print '\toriginal', psp_fits_original.nrmse()
                                print '\tsimple', psp_fits_simple.nrmse()
# Example 4 (fragment from a code-example listing; begins mid-function)
 expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
 color = color_palette[c]
 grand_response[type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
 expt_ids[type[0]] = []
 synapse_plot[c, 0].addLegend()
 for expt in expt_list:
     for pre, post in expt.connections:
         if [expt.uid, pre, post] in no_include:
             continue
         cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
         layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
         if cre_check is True and layer_check is True:
             pulse_response, artifact = get_response(expt, pre, post, type='pulse')
             if threshold is not None and artifact > threshold:
                 continue
             response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
             if len(response_subset) >= sweep_threshold:
                 qc_plot.clear()
                 qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                 if len(qc_list) >= sweep_threshold:
                     avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                     if amp_sign is '-':
                         continue
                     #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                     all_amps = fail_rate(response_subset, '+', peak_t)
                     cv = np.std(all_amps)/np.mean(all_amps)
                     
                     # weight parts of the trace during fitting
                     dt = avg_trace.dt
                     weight = np.ones(len(avg_trace.data))*10.  #set everything to ten initially
                     weight[int(10e-3/dt):int(12e-3/dt)] = 0.   #area around stim artifact
def first_pulse_plot(expt_list,
                     name=None,
                     summary_plot=None,
                     color=None,
                     scatter=0,
                     features=False):
    """Plot averaged first-pulse responses per connection and their grand mean.

    A connection in ``expt_list`` is included when all of the following hold:
      - its pre/post cre types equal the module-level ``cre_type`` pair;
      - its artifact is at most 0.03e-3;
      - at least 10 sweeps pass ``response_filter`` (0-50 Hz, holding range
        [-68, -72]);
      - the sign of its average amplitude agrees with the presynaptic class
        (excitatory cre types must not be negative, inhibitory ones must not
        be positive).

    Each included average is plotted. It is drawn in red when ``args.recip``
    is True and the reverse connection (post -> pre) is also present in
    ``expt.connections``. The grand mean is then plotted (green, width 3),
    with a horizontal line at the mean amplitude.

    Parameters
    ----------
    expt_list : iterable
        Experiments to scan.
    name : optional
        Label for the printed count and the grand-mean legend entry.
    summary_plot : optional
        Plot passed to ``summary_plot_pulse``.
    color : optional
        Color passed to ``summary_plot_pulse``.
    scatter : int
        Forwarded to ``summary_plot_pulse`` as ``i``.
    features : bool
        When True, each average is also fitted with ``fit_psp`` to collect
        latency (fitted ``xoffset`` minus 10e-3) and rise time.

    Returns
    -------
    tuple
        ``(avg_amps, summary_plots)``. ``avg_amps`` holds the ``'amp'``,
        ``'latency'`` and ``'rise'`` lists. ``summary_plots`` is None when no
        connection qualified.
    """
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    grand_response = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}
    for expt in expt_list:
        if expt.connections is None:
            continue
        for pre, post in expt.connections:
            pre_cre = expt.cells[pre].cre_type
            if pre_cre != cre_type[0] or expt.cells[post].cre_type != cre_type[1]:
                continue
            all_responses, artifact = get_response(
                expt, pre, post, analysis_type='pulse')
            if artifact > 0.03e-3:
                continue
            filtered_responses = response_filter(
                all_responses,
                freq_range=[0, 50],
                holding_range=[-68, -72],
                pulse=True)
            n_sweeps = len(filtered_responses)
            if n_sweeps < 10:
                continue
            avg_trace, avg_amp, amp_sign, _ = get_amplitude(
                filtered_responses)
            # Drop averages whose sign contradicts the presynaptic cell class.
            if pre_cre in EXCITATORY_CRE_TYPES and avg_amp < 0:
                continue
            elif pre_cre in INHIBITORY_CRE_TYPES and avg_amp > 0:
                continue
            avg_trace.t0 = 0
            avg_amps['amp'].append(avg_amp)
            grand_response.append(avg_trace)
            if features is True:
                psp_fits = fit_psp(avg_trace,
                                   sign=amp_sign,
                                   yoffset=0,
                                   amp=avg_amp,
                                   method='leastsq',
                                   fit_kws={})
                avg_amps['latency'].append(
                    psp_fits.best_values['xoffset'] - 10e-3)
                avg_amps['rise'].append(
                    psp_fits.best_values['rise_time'])

            # A connection is drawn in red when the reverse connection
            # (post -> pre) also exists. This membership test replaces an
            # equivalent enumerate/break search over expt.connections.
            is_reciprocal = (len(expt.connections) > 1 and args.recip is True
                             and (post, pre) in expt.connections)
            if is_reciprocal:
                amp_plots.plot(avg_trace.time_values,
                               avg_trace.data,
                               pen={
                                   'color': 'r',
                                   'width': 1
                               })
            else:
                amp_plots.plot(avg_trace.time_values,
                               avg_trace.data)

            app.processEvents()

    if len(grand_response) != 0:
        # %-formatting so the default name=None does not raise TypeError.
        print('%s n = %d' % (name, len(grand_response)))
        grand_mean = TraceList(grand_response).mean()
        grand_amp = np.mean(np.array(avg_amps['amp']))
        grand_amp_sem = stats.sem(np.array(avg_amps['amp']))
        amp_plots.addLegend()
        amp_plots.plot(grand_mean.time_values,
                       grand_mean.data,
                       pen={
                           'color': 'g',
                           'width': 3
                       },
                       name=name)
        amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
        if grand_mean is not None:
            # NOTE(review): `legend` is not defined in this function; it must
            # be a module-level name.
            print(legend + ' Grand mean amplitude = %f +- %f' %
                  (grand_amp, grand_amp_sem))
            if features is True:
                feature_list = (avg_amps['amp'], avg_amps['latency'],
                                avg_amps['rise'])
                labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
                titles = ('Amplitude', 'Latency', 'Rise time')
            else:
                # NOTE(review): these are a plain list and a plain string, not
                # 1-tuples parallel to the branch above. Only feature_list[0]
                # is passed below in either case -- confirm what
                # summary_plot_pulse expects.
                feature_list = [avg_amps['amp']]
                labels = (['Vm', 'V'])
                titles = 'Amplitude'
            summary_plots = summary_plot_pulse(feature_list[0],
                                               labels=labels,
                                               titles=titles,
                                               i=scatter,
                                               grand_trace=grand_mean,
                                               plot=summary_plot,
                                               color=color,
                                               name=legend)
            return avg_amps, summary_plots
    else:
        print("No Traces")
        return avg_amps, None
def train_response_plot(expt_list,
                        name=None,
                        summary_plots=None,
                        color=None):
    """Plot 50 Hz train responses per connection and their grand averages.

    For each connection in ``expt_list`` whose pre/post cre types equal the
    module-level ``cre_type`` pair:
      - train responses are fetched with ``get_response``;
      - ``response_filter`` splits them into two groups, named induction and
        recovery here;
      - connections with an artifact above 0.03e-3, or with 5 or fewer sweeps
        in the first group, are skipped;
      - per-connection averages are plotted on a new plot.

    The grand means are then plotted (green, width 3). The induction grand
    mean is also deconvolved with ``exp_deconvolve`` (tau = 15e-3) and
    filtered with ``bessel_filter`` (lp = 1000). Finally everything is drawn
    through ``summary_plot_train``.

    Parameters
    ----------
    expt_list : iterable
        Experiments to scan.
    name : optional
        Label for the printed connection count and the legend entries.
    summary_plots : sequence of two plots, optional
        Targets for ``summary_plot_train``. Index 0 receives the induction and
        recovery grand means; index 1 receives the deconvolved induction grand
        mean. Defaults to ``[None, None]``.
    color : optional
        Color passed to ``summary_plot_train``.

    Returns
    -------
    tuple or None
        ``(train_plots, train_plots2, train_amps)`` when at least one
        connection qualified, otherwise ``None``.
    """
    if summary_plots is None:
        # Fresh list per call instead of a shared mutable default argument.
        summary_plots = [None, None]
    grand_train = [[], []]  # [induction averages, recovery averages]
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3
    lp = 1000
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[
                    post].cre_type == cre_type[1]:
                print('Processing experiment: %s' % (expt.nwb_file))

                train_responses, artifact = get_response(expt,
                                                         pre,
                                                         post,
                                                         analysis_type='train')
                if artifact > 0.03e-3:
                    continue

                train_filter = response_filter(train_responses['responses'],
                                               freq_range=[50, 50],
                                               train=0,
                                               delta_t=250)
                # NOTE(review): reassigned for every matching connection, so
                # train_amp() below receives the offsets of the last connection
                # processed, even one later rejected by the sweep-count check.
                pulse_offsets = response_filter(
                    train_responses['pulse_offsets'],
                    freq_range=[50, 50],
                    train=0,
                    delta_t=250)

                if len(train_filter[0]) > 5:
                    ind_avg = TraceList(train_filter[0]).mean()
                    rec_avg = TraceList(train_filter[1]).mean()
                    rec_avg.t0 = 0.3
                    grand_train[0].append(ind_avg)
                    grand_train[1].append(rec_avg)
                    train_plots.plot(ind_avg.time_values, ind_avg.data)
                    train_plots.plot(rec_avg.time_values, rec_avg.data)
                    app.processEvents()
    if len(grand_train[0]) != 0:
        # %-formatting so the default name=None does not raise TypeError.
        print('%s n = %d' % (name, len(grand_train[0])))
        ind_grand_mean = TraceList(grand_train[0]).mean()
        rec_grand_mean = TraceList(grand_train[1]).mean()
        ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau),
                                           lp)
        train_plots.addLegend()
        train_plots.plot(ind_grand_mean.time_values,
                         ind_grand_mean.data,
                         pen={
                             'color': 'g',
                             'width': 3
                         },
                         name=name)
        train_plots.plot(rec_grand_mean.time_values,
                         rec_grand_mean.data,
                         pen={
                             'color': 'g',
                             'width': 3
                         },
                         name=name)
        train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets,
                               '+')
        if ind_grand_mean is not None:
            # NOTE(review): `legend` is not defined in this function; it must
            # be a module-level name.
            train_plots = summary_plot_train(ind_grand_mean,
                                             plot=summary_plots[0],
                                             color=color,
                                             name=(legend +
                                                   ' 50 Hz induction'))
            train_plots = summary_plot_train(rec_grand_mean,
                                             plot=summary_plots[0],
                                             color=color)
            train_plots2 = summary_plot_train(ind_grand_mean_dec,
                                              plot=summary_plots[1],
                                              color=color,
                                              name=(legend +
                                                    ' 50 Hz induction'))
            return train_plots, train_plots2, train_amps
    else:
        print("No Traces")
        return None
# Example 7 (fragment from a code-example listing; begins mid-function)
     connection_type = tuple(connection_type.split('-'))
 expt_id, pre_cell, post_cell = connection_types[
     connection_type]  #all_connections[connection_type]
 expt = all_expts[expt_id]
 if expt.cells[pre_cell].cre_type in EXCITATORY_CRE_TYPES:
     holding = holding_e
     sign = '+'
 elif expt.cells[pre_cell].cre_type in INHIBITORY_CRE_TYPES:
     holding = holding_i
     sign = '-'
 pulse_response, artifact = get_response(expt,
                                         pre_cell,
                                         post_cell,
                                         type='pulse')
 sweep_list = response_filter(pulse_response,
                              freq_range=[0, 50],
                              holding_range=holding,
                              pulse=True)
 n_sweeps = len(sweep_list)
 if n_sweeps > sweep_threshold:
     qc_list = pulse_qc(sweep_list,
                        baseline=2.5,
                        pulse=None,
                        plot=grid[row, 1])
     qc_sweeps = len(qc_list)
     if qc_sweeps > sweep_threshold:
         avg_first_pulse = trace_avg(qc_list)
         avg_first_pulse.t0 = 0
         if plot_sweeps is True:
             for current_sweep in qc_list:
                 current_sweep.t0 = 0
                 bsub_trace = bsub(current_sweep)
# Example 8 (fragment from a code-example listing; begins mid-function)
 for pre, post in expt.connections:
     if [expt.uid, pre, post] in no_include:
         continue
     cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[
         post].cre_type == cre_type[1]
     layer_check = expt.cells[pre].target_layer == target_layer[
         0] and expt.cells[post].target_layer == target_layer[1]
     if cre_check is True and layer_check is True:
         pulse_response, artifact = get_response(expt,
                                                 pre,
                                                 post,
                                                 type='pulse')
         if threshold is not None and artifact > threshold:
             continue
         response_subset = response_filter(pulse_response,
                                           freq_range=[0, 50],
                                           holding_range=[-68, -72],
                                           pulse=True)
         if len(response_subset) >= sweep_threshold:
             qc_plot.clear()
             qc_list = pulse_qc(response_subset,
                                baseline=2.5,
                                pulse=None,
                                plot=qc_plot)
             if len(qc_list) >= sweep_threshold:
                 avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(
                     qc_list)
                 if amp_sign is '-':
                     continue
                 #grand_response[cre_type[0]]['fail_rate'].append(fail_rate(response_subset, '+', peak_t))
                 psp_fits = fit_psp(avg_trace,
                                    sign=amp_sign,
# Example 9 (from a code-example listing)
def save_fit_psp_test_set():
    """Create a test set of data for testing the ``fit_psp`` function.

    NOTE: THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Uses Steph's original first_pulse_feature.py code to filter out
    error-causing data. For each connection of the requested type that passes
    the artifact threshold, the holding/frequency filter and pulse QC (each
    requiring at least ``sweep_threshold`` sweeps), the averaged first-pulse
    trace and the ``fit_psp`` inputs are collected in ``save_dict['input']``.
    The fit is then run both on the averaged Trace and on a Trace rebuilt from
    the serialized data, and any NRMSE mismatch between the two is printed.

    Example run statement::

        python save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom (no saving code is
    present in this version of the function).

    ``--connection`` accepts a named group ('ee', 'ii', 'ei', 'ie', 'all' for
    mouse; 'ee' for human) or a 'pre-post' pair. An unrecognized organism or
    connection spec exits via ``parser.error``.
    """

    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description=
        'Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
        'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism',
                        dest='organism',
                        help='Select mouse or human')
    parser.add_argument('--connection',
                        dest='connection',
                        help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))
    connection = args['connection']
    if connection is None:
        # Both organism branches below parse this string.
        parser.error('--connection is required')

    def _run_fit(trace, fit_input):
        """Run ``fit_psp`` on ``trace`` with the parameters stored in ``fit_input``."""
        # A fresh fit_kws dict is built on every call so the two fits below
        # never share that argument object.
        return fit_psp(trace,
                       sign=fit_input['amp_sign'],
                       yoffset=fit_input['yoffset'],
                       xoffset=fit_input['xoffset'],
                       amp=fit_input['avg_amp'],
                       method=fit_input['method'],
                       stacked=fit_input['stacked'],
                       rise_time_mult_factor=fit_input['rise_time_mult_factor'],
                       fit_kws={'weights': fit_input['weight']})

    all_expts = cached_experiments()
    manifest = {
        'Type': [],
        'Connection': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'rise2080': [],
        'rise1090': [],
        'rise1080': [],
        'decay': [],
        'nrmse': [],
        'CV': []
    }
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        named_groups = {
            'ee': ee_connections,
            'ii': ii_connections,
            'ei': ei_connections,
            'ie': ie_connections,
            'all': all_connections,
        }
        c_type = connection.split('-')
        if connection in named_groups:
            # list() keeps the keys indexable (dict views are not on Python 3).
            connection_types = list(named_groups[connection].keys())
        elif len(c_type) == 2:
            # Entries are (target_layer, cre_type): '2/3' selects layer 2/3 with
            # cre type 'unknown'; anything else is a cre type with layer None.
            pre_type = ('2/3', 'unknown') if c_type[0] == '2/3' else (None, c_type[0])
            post_type = ('2/3', 'unknown') if c_type[1] == '2/3' else (None, c_type[1])
            connection_types = [(pre_type, post_type)]
        else:
            parser.error('unrecognized mouse connection: %r' % connection)
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        if connection == 'ee':
            connection_types = list(human_connections.keys())
        else:
            c_type = connection.split('-')
            if len(c_type) != 2:
                parser.error('unrecognized human connection: %r' % connection)
            # Entries are (target_layer, cre_type); the cre type is fixed to 'unknown'.
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    else:
        parser.error("--organism must be 'mouse' or 'human'")

    plt = pg.plot()

    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    for c, conn_type in enumerate(connection_types):
        cre_type = (conn_type[0][1], conn_type[1][1])
        target_layer = (conn_type[0][0], conn_type[1][0])
        expt_list = all_expts.select(cre_type=cre_type,
                                     target_layer=target_layer,
                                     calcium=calcium,
                                     age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {
            'trace': [],
            'amp': [],
            'latency': [],
            'rise': [],
            'dist': [],
            'decay': [],
            'CV': [],
            'amp_measured': []
        }
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = (expt.cells[pre].cre_type == cre_type[0] and
                             expt.cells[post].cre_type == cre_type[1])
                layer_check = (expt.cells[pre].target_layer == target_layer[0] and
                               expt.cells[post].target_layer == target_layer[1])
                if not (cre_check is True and layer_check is True):
                    continue

                pulse_response, artifact = get_response(
                    expt, pre, post, analysis_type='pulse')
                if threshold is not None and artifact > threshold:
                    continue
                response_subset, hold = response_filter(
                    pulse_response,
                    freq_range=[0, 50],
                    holding_range=holding,
                    pulse=True)
                if len(response_subset) < sweep_threshold:
                    continue
                qc_plot.clear()
                qc_list = pulse_qc(response_subset,
                                   baseline=1.5,
                                   pulse=None,
                                   plot=qc_plot)
                if len(qc_list) < sweep_threshold:
                    continue

                avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                # Disabled in the original analysis (kept for traceability):
                #     if amp_sign == '-':
                #         continue
                #     #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                #     all_amps = fail_rate(response_subset, '+', peak_t)
                #     cv = np.std(all_amps)/np.mean(all_amps)

                # weight parts of the trace during fitting
                dt = avg_trace.dt
                weight = np.ones(len(avg_trace.data)) * 10.  # set everything to ten initially
                weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
                weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

                # check if the test data dir is there and if not create it
                test_data_dir = 'test_psp_fit'
                if not os.path.isdir(test_data_dir):
                    os.mkdir(test_data_dir)

                save_dict = {}
                save_dict['input'] = {
                    'data': avg_trace.data.tolist(),
                    'dtype': str(avg_trace.data.dtype),
                    'dt': float(avg_trace.dt),
                    'amp_sign': amp_sign,
                    'yoffset': 0,
                    'xoffset': 14e-3,
                    'avg_amp': float(avg_amp),
                    'method': 'leastsq',
                    'stacked': False,
                    'rise_time_mult_factor': 10.,
                    'weight': weight.tolist()
                }

                # need to remake trace because different output is created
                avg_trace_simple = Trace(
                    data=np.array(save_dict['input']['data']),
                    dt=save_dict['input']['dt'])  # create Trace object

                psp_fits_original = _run_fit(avg_trace, save_dict['input'])
                psp_fits_simple = _run_fit(avg_trace_simple, save_dict['input'])

                # Single-argument print() gives identical output on Python 2 and 3.
                print('%s %s %s' % (expt.uid, pre, post))
                nrmse_original = psp_fits_original.nrmse()
                nrmse_simple = psp_fits_simple.nrmse()
                if nrmse_original != nrmse_simple:
                    print('  the nrmse values dont match')
                    print('\toriginal %s' % nrmse_original)
                    print('\tsimple %s' % nrmse_simple)
def first_pulse_plot(expt_list, name=None, summary_plot=None, color=None, scatter=0, features=False):
    """Plot averaged first-pulse responses for every matching connection.

    A connection (pre, post) is used when the pre/post cre types equal the
    module-level ``cre_type`` pair, its artifact from ``get_response`` is at
    most 0.03e-3, at least 10 sweeps survive ``response_filter``, and the sign
    of its average amplitude agrees with the presynaptic cre type (excitatory
    types must not be negative, inhibitory types must not be positive).

    Parameters
    ----------
    expt_list : iterable of experiments to scan.
    name : legend label of the grand-mean trace; also prefixes the
        "n = ..." printout.
    summary_plot : forwarded to ``summary_plot_pulse`` as ``plot``.
    color : forwarded to ``summary_plot_pulse``.
    scatter : forwarded to ``summary_plot_pulse`` as ``i``.
    features : when True, each average is also fitted with ``fit_psp`` and
        its latency (``xoffset`` minus 10e-3) and rise time are collected.

    Returns
    -------
    ``(avg_amps, summary_plots)``, where ``avg_amps`` has the lists 'amp',
    'latency' and 'rise', and ``summary_plots`` is the result of
    ``summary_plot_pulse``. If nothing qualified, "No Traces" is printed
    and ``(avg_amps, None)`` is returned.

    Relies on module-level ``cre_type``, ``args`` (``args.recip``), ``app``
    and ``legend``. NOTE(review): the grand-mean printout and the summary plot
    use the global ``legend`` rather than the ``name`` parameter -- confirm
    this is intended.
    """
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    grand_response = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}

    for expt in expt_list:
        if expt.connections is None:
            continue
        for pre, post in expt.connections:
            pre_cre = expt.cells[pre].cre_type
            if not (pre_cre == cre_type[0] and expt.cells[post].cre_type == cre_type[1]):
                continue

            responses, artifact = get_response(expt, pre, post, analysis_type='pulse')
            if artifact > 0.03e-3:
                continue
            sweeps = response_filter(responses, freq_range=[0, 50], holding_range=[-68, -72], pulse=True)
            if len(sweeps) < 10:
                continue

            avg_trace, avg_amp, amp_sign, _ = get_amplitude(sweeps)
            # Skip averages whose sign contradicts the presynaptic cell class.
            wrong_sign = ((pre_cre in EXCITATORY_CRE_TYPES and avg_amp < 0) or
                          (pre_cre in INHIBITORY_CRE_TYPES and avg_amp > 0))
            if wrong_sign:
                continue

            avg_trace.t0 = 0
            avg_amps['amp'].append(avg_amp)
            grand_response.append(avg_trace)
            if features is True:
                psp_fits = fit_psp(avg_trace, sign=amp_sign, yoffset=0, amp=avg_amp, method='leastsq',
                                   fit_kws={})
                best = psp_fits.best_values
                avg_amps['latency'].append(best['xoffset'] - 10e-3)
                avg_amps['rise'].append(best['rise_time'])

            # With --recip, traces whose reverse connection (post -> pre) also
            # exists in this experiment are drawn in red.
            is_reciprocal = (len(expt.connections) > 1 and args.recip is True
                             and (post, pre) in expt.connections)
            if is_reciprocal:
                amp_plots.plot(avg_trace.time_values, avg_trace.data, pen={'color': 'r', 'width': 1})
            else:
                amp_plots.plot(avg_trace.time_values, avg_trace.data)

            app.processEvents()

    if not grand_response:
        print("No Traces")
        return avg_amps, None

    print(name + ' n = %d' % len(grand_response))
    grand_mean = TraceList(grand_response).mean()
    amps = np.array(avg_amps['amp'])
    grand_amp = np.mean(amps)
    grand_amp_sem = stats.sem(amps)
    amp_plots.addLegend()
    amp_plots.plot(grand_mean.time_values, grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
    print(legend + ' Grand mean amplitude = %f +- %f' % (grand_amp, grand_amp_sem))

    if features is True:
        labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
        titles = ('Amplitude', 'Latency', 'Rise time')
    else:
        # NOTE(review): a plain list, not a 1-tuple of label pairs -- confirm
        # this is what summary_plot_pulse expects.
        labels = ['Vm', 'V']
        titles = 'Amplitude'
    # Only the amplitude list is handed to the summary plot, in either case.
    summary_plots = summary_plot_pulse(avg_amps['amp'], labels=labels, titles=titles, i=scatter,
                                       grand_trace=grand_mean, plot=summary_plot, color=color, name=legend)
    return avg_amps, summary_plots
maxYtrain = []
maxYdec = []
for row in range(len(connection_types)):
    connection_type = connection_types.keys()[row]
    if type(connection_type) is not tuple:
        connection_type = tuple(connection_type.split('-'))
    expt_id, pre_cell, post_cell = connection_types[connection_type] #all_connections[connection_type]
    expt = all_expts[expt_id]
    if expt.cells[pre_cell].cre_type in EXCITATORY_CRE_TYPES:
        holding = holding_e
        sign = '+'
    elif expt.cells[pre_cell].cre_type in INHIBITORY_CRE_TYPES:
        holding = holding_i
        sign = '-'
    pulse_response, artifact = get_response(expt, pre_cell, post_cell, analysis_type='pulse')
    sweep_list = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
    n_sweeps = len(sweep_list[0])
    if n_sweeps > sweep_threshold:
        qc_list = pulse_qc(sweep_list, baseline=2, pulse=None, plot=pg.plot())
        qc_sweeps = len(qc_list)
        if qc_sweeps > sweep_threshold:
            avg_first_pulse = trace_avg(qc_list)
            avg_first_pulse.t0 = 0
            if plot_sweeps is True:
                for current_sweep in qc_list:
                    current_sweep.t0 = 0
                    bsub_trace = bsub(current_sweep)
                    trace_plot(bsub_trace, sweep_color, plot=grid[row, 0], x_range=[-2e-3, 27e-3])
            trace_plot(avg_first_pulse, avg_color, plot=grid[row, 0], x_range=[-2e-3, 27e-3])
            label = pg.LabelItem('%s, n = %d' % (connection_type, qc_sweeps))
            label.setParentItem(grid[row, 0].vb)