예제 #1
0
def populatemytables_core_paralel(arguments, runround):
    """Populate the ``ephysanal`` tables that belong to one processing round.

    Parameters
    ----------
    arguments : dict
        Keyword arguments forwarded unchanged to every ``populate()`` call.
    runround : int
        Round selector. Rounds 1-4 each populate a fixed group of tables in
        a fixed order; any other value populates nothing.
    """
    # (round id, table class names on ``ephysanal``) in execution order.
    rounds = (
        (1, ('SquarePulse', 'SweepFrameTimes')),
        (2, ('SquarePulseSeriesResistance', 'SweepSeriesResistance')),
        (3, ('SweepResponseCorrected',)),
        (4, ('ActionPotential', 'ActionPotentialDetails')),
    )
    for round_id, table_names in rounds:
        if runround == round_id:
            for table_name in table_names:
                # Each table is instantiated right before it is populated.
                getattr(ephysanal, table_name)().populate(**arguments)
            break
예제 #2
0
def plot_cell_SN_ratio_APwise(roi_type = 'VolPy',v0_max = -35,holding_min = -600,frame_rate_min =300, frame_rate_max = 800 ,bin_num = 10 ):
    """Plot per-AP S/N ratio, noise and peak amplitude against F0 for each cell.

    For every cell that has ground-truth ROIs of ``roi_type`` the function
    fetches the AP-wise imaging statistics, draws raw and F0-binned scatter
    plots (one colour per cell), and finally draws a histogram of the mean
    S/N ratio of the first 50 APs of the cells that pass the ``v0_max`` and
    ``holding_min`` filters.

    Parameters
    ----------
    roi_type : str
        ``roi_type`` value used to restrict ``imaging_gt.GroundTruthROI``;
        also used as the histogram title.
    v0_max : float
        Only cells whose median response trace (x1000, i.e. mV) is below this
        value enter the histogram.
    holding_min : float
        Only cells whose median stimulus (x1e12, i.e. pA) is above this value
        enter the histogram.
    frame_rate_min, frame_rate_max : float
        Exclusive bounds on ``movie_frame_rate`` for the APs that are used.
    bin_num : int
        Controls the F0 bin width: the F0 range of each cell is divided into
        ``bin_num + 1`` equal steps.

    Returns
    -------
    None
        The function only produces matplotlib figures and prints the cell
        table before and after filtering.
    """
    #%% Show S/N ratios for each AP
    # plt.get_cmap is available in every matplotlib release, unlike cm.get_cmap.
    cmap = plt.get_cmap('jet')
    key = {'roi_type':roi_type}
    gtdata = pd.DataFrame((imaging_gt.GroundTruthROI()&key))
    cells = gtdata.groupby(['session', 'subject_id','cell_number','motion_correction_method','roi_type']).size().reset_index(name='Freq')
    snratio = list()
    v0s = list()
    holdings = list()
    rss = list()
    threshs =list()
    mintreshs = list()
    f0s_all = list()
    snratios_all = list()
    peakamplitudes_all = list()
    noise_all = list()
    for _, cell in cells.iterrows():
        key_cell = dict(cell)
        # 'Freq' is the groupby counter, not a database attribute.
        del key_cell['Freq']
        snratios,f0,peakamplitudes,noises = (imaging.Movie()*imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()*ephysanal.ActionPotentialDetails()&key_cell&'ap_real = 1'&'movie_frame_rate > {}'.format(frame_rate_min)&'movie_frame_rate < {}'.format(frame_rate_max)).fetch('apwave_snratio','apwave_f0','apwave_peak_amplitude','apwave_noise')

        # The first sweep that has an AP wave is used to characterise the cell.
        sweep = (imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()&key_cell).fetch('sweep_number')[0]
        thresh = (imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()*ephysanal.ActionPotentialDetails()&key_cell&'ap_real = 1').fetch('ap_threshold')
        sweep_restriction = 'sweep_number = {}'.format(sweep)
        trace = (ephys_patch.SweepResponse()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('response_trace')[0]
        stimulus = (ephys_patch.SweepStimulus()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('stimulus_trace')[0]
        RS = (ephysanal.SweepSeriesResistance()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('series_resistance')[0]

        medianvoltage = np.median(trace)*1000
        holding = np.median(stimulus)*10**12
        snratio.append(np.mean(snratios[:50]))
        v0s.append(medianvoltage)
        holdings.append(holding)
        rss.append(RS)
        threshs.append(thresh)
        mintreshs.append(np.min(thresh))
        f0s_all.append(f0)
        snratios_all.append(snratios)
        peakamplitudes_all.append(peakamplitudes)
        noise_all.append(noises)

    #%%  for each AP

    fig=plt.figure()
    ax_sn_f0 = fig.add_axes([0,0,1,1])
    ax_sn_f0_binned = fig.add_axes([0,-1.2,1,1])
    ax_noise_f0 = fig.add_axes([2.6,0,1,1])
    ax_noise_f0_binned = fig.add_axes([2.6,-1.2,1,1])
    ax_peakampl_f0 = fig.add_axes([1.3,0,1,1])
    ax_peakampl_f0_binned = fig.add_axes([1.3,-1.2,1,1])
    for loopidx, (f0,snratio_now,noise_now,peakampl_now,cell_now) in enumerate(zip(f0s_all,snratios_all,noise_all,peakamplitudes_all,cells.iterrows())):
        if len(f0)>0:
            color_now = cmap(loopidx/len(cells))
            cell_now = cell_now[1]
            label_now = 'Subject:{}'.format(cell_now['subject_id'])+' Cell:{}'.format(cell_now['cell_number'])
            ax_sn_f0.plot(f0,snratio_now,'o',ms=1, color = color_now, label= label_now)
            ax_noise_f0.plot(f0,noise_now,'o',ms=1, color = color_now, label= label_now)
            ax_peakampl_f0.plot(f0,peakampl_now,'o',ms=1, color = color_now, label= label_now)

            f0_low = np.min(f0)
            f0_high = np.max(f0)
            bin_width = (f0_high-f0_low)/(bin_num+1)
            if bin_width > 0:
                lows = np.arange(f0_low,f0_high,bin_width)
                highs = lows + bin_width
            else:
                # All F0 values are identical (e.g. a single AP): np.arange
                # cannot be called with a zero step, so there is nothing to bin.
                lows = highs = []
            mean_f0 = list()
            sd_f0 = list()
            mean_sn = list()
            sd_sn =list()
            mean_noise = list()
            sd_noise =list()
            mean_ampl = list()
            sd_ampl =list()
            for low,high in zip(lows,highs):
                idx = (f0 >= low) & (f0 < high)
                # Count the APs inside the bin; len(idx) would be the length
                # of the whole mask and never reflect the bin occupancy.
                if np.count_nonzero(idx)>10:
                    mean_f0.append(np.mean(f0[idx]))
                    sd_f0.append(np.std(f0[idx]))
                    mean_sn.append(np.mean(snratio_now[idx]))
                    sd_sn.append(np.std(snratio_now[idx]))
                    mean_noise.append(np.mean(noise_now[idx]))
                    sd_noise.append(np.std(noise_now[idx]))
                    mean_ampl.append(np.mean(peakampl_now[idx]))
                    sd_ampl.append(np.std(peakampl_now[idx]))

            ax_sn_f0_binned.errorbar(mean_f0,mean_sn,sd_sn,sd_f0,'o-', color = color_now, label= label_now)
            ax_noise_f0_binned.errorbar(mean_f0,mean_noise,sd_noise,sd_f0,'o-', color = color_now, label= label_now)
            ax_peakampl_f0_binned.errorbar(mean_f0,mean_ampl,sd_ampl,sd_f0,'o-', color = color_now, label= label_now)

    ax_sn_f0.set_xlabel('F0')
    ax_sn_f0.set_ylabel('S/N ratio')
    ax_sn_f0_binned.set_xlabel('F0')
    ax_sn_f0_binned.set_ylabel('S/N ratio')

    # The legend is anchored outside the binned S/N axes.
    ax_sn_f0_binned.legend(loc='upper center', bbox_to_anchor=(-.45, 1.5), shadow=True, ncol=1)

    ax_noise_f0.set_xlabel('F0')
    ax_noise_f0.set_ylabel('Noise (std(dF/F))')
    ax_noise_f0_binned.set_xlabel('F0')
    ax_noise_f0_binned.set_ylabel('Noise (std(dF/F))')

    ax_peakampl_f0.set_xlabel('F0')
    ax_peakampl_f0.set_ylabel('Peak amplitude (dF/F)')
    ax_peakampl_f0_binned.set_xlabel('F0')
    ax_peakampl_f0_binned.set_ylabel('Peak amplitude (dF/F)')

    #%%
    cells['SN']=snratio
    cells['V0']=v0s
    cells['holding']=holdings
    cells['RS']=np.asarray(rss,float)
    print(cells)
    cells = cells[cells['V0']<v0_max]
    cells = cells[cells['holding']>holding_min]
    print(cells)
    #% S/N ratio histogram
    fig=plt.figure()
    ax_hist = fig.add_axes([0,0,1,1])
    ax_hist.hist(cells['SN'].values)
    ax_hist.set_xlabel('S/N ratio of first 50 spikes')
    ax_hist.set_ylabel('# of cells')
    ax_hist.set_title(roi_type.replace('_',' '))
    ax_hist.set_xlim([0,15])
예제 #3
0
def plot_AP_waveforms(key,
                      AP_tlimits = (-.005,.01),
                      bin_step = .00001,
                      bin_size = .00025,
                      save_image = False):
    """Plot optical and electrical AP waveforms for every movie of one cell.

    For each movie that contains real APs (``ap_real = 1``) a figure with four
    axes is created: the raw dF/F AP waves, their sliding-window average, the
    ephys traces around each AP peak, and the mean ephys trace before and
    after convolution with a kernel and a frame-length boxcar.

    Parameters
    ----------
    key : dict
        DataJoint restriction selecting the cell; ``key['subject_id']`` and
        ``key['cell_number']`` are also used in the title and file name.
    AP_tlimits : sequence of two floats
        Start and end of the plotted window relative to the AP, in seconds
        (multiplied by 1000 for the axes labelled 'ms').
    bin_step : float
        Step between consecutive bin centres of the sliding average, in the
        units of ``apwave_time``.
    bin_size : float
        Full width of each averaging bin, in the same units.
    save_image : bool
        If true, each figure is written to ``./figures/``.

    Returns
    -------
    None
    """
    select_high_sn_APs = False
    # Kernel time constants (presumably seconds) and fast-component weights.
    tau_1_on = .64/1000
    tau_2_on = 4.1/1000
    tau_1_ratio_on =  .61
    tau_1_off = .78/1000
    tau_2_off = 3.9/1000
    tau_1_ratio_off = 55  # NOTE(review): differs by ~100x from the "on" ratio; confirm
    movie_numbers,sweep_numbers,apwavetimes,apwaves,famerates,snratio,apnums,ap_threshold = ((imaging_gt.GroundTruthROI()*imaging.Movie()*imaging_gt.ROIAPWave()*ephysanal.ActionPotentialDetails())&key&'ap_real = 1').fetch('movie_number','sweep_number','apwave_time','apwave_dff','movie_frame_rate','apwave_snratio','ap_num','ap_threshold')
    uniquemovienumbers = np.unique(movie_numbers)
    for movie_number in uniquemovienumbers:

        fig=plt.figure()
        ax_ephys=fig.add_axes([0,0,1,1])
        ax_raw=fig.add_axes([0,1.1,1,1])
        ax_bin=fig.add_axes([1.3,1.1,1,1])
        ax_e_convolved = fig.add_axes([1.3,0,1,1])
        ax_bin.set_title('{} ms binning'.format(bin_size*1000))

        aps_now = movie_numbers == movie_number
        if select_high_sn_APs :
            # Optionally keep only the APs above the median S/N of this movie.
            medsn = np.median(snratio[aps_now])
            aps_now = (movie_numbers == movie_number) & (snratio>medsn)

        framerate = famerates[aps_now][0]
        apwavetimes_conc = np.concatenate(apwavetimes[aps_now])
        apwaves_conc = np.concatenate(apwaves[aps_now])
        prev_sweep = None
        ephys_vs = list()
        for apwavetime,apwave,sweep_number,ap_num in zip(apwavetimes[aps_now],apwaves[aps_now],sweep_numbers[aps_now],apnums[aps_now]):
            # Keep one extra frame on both sides of the requested window.
            wave_needed_idx = (apwavetime>=AP_tlimits[0]-1/framerate) & (apwavetime<=AP_tlimits[1]+1/framerate)
            ax_raw.plot(apwavetime[wave_needed_idx ]*1000,apwave[wave_needed_idx ])
            if prev_sweep != sweep_number:
                # Fetch the sweep only when it changes between consecutive APs.
                trace = (ephys_patch.SweepResponse()&key&'sweep_number = {}'.format(sweep_number)).fetch1('response_trace')*1000
                e_sr = (ephys_patch.SweepMetadata()&key&'sweep_number = {}'.format(sweep_number)).fetch1('sample_rate')
                stepback = int(np.abs(np.round(AP_tlimits[0]*e_sr)))
                stepforward = int(np.abs(np.round(AP_tlimits[1]*e_sr)))
                ephys_t = np.arange(-stepback,stepforward)/e_sr * 1000
                prev_sweep = sweep_number
            apmaxindex = (ephysanal.ActionPotential()&key & 'sweep_number = {}'.format(sweep_number) & 'ap_num = {}'.format(ap_num)).fetch1('ap_max_index')
            ephys_v = trace[apmaxindex-stepback:apmaxindex+stepforward]
            ephys_vs.append(ephys_v)
            ax_ephys.plot(ephys_t,ephys_v)

        mean_ephys_v = np.mean(np.asarray(ephys_vs),0)

        t = np.arange(0,.01,1/e_sr)
        # NOTE(review): the first term uses exp(+t/tau_1_on), i.e. it grows
        # with t instead of decaying like the second term - confirm intent.
        f_on = tau_1_ratio_on*np.exp(t/tau_1_on) + (1-tau_1_ratio_on)*np.exp(-t/tau_2_on)
        f_off = tau_1_ratio_off*np.exp(t[::-1]/tau_1_off) + (1-tau_1_ratio_off)*np.exp(-t[::-1]/tau_2_off)
        f_on = f_on/np.max(f_on)
        f_off = f_off/np.max(f_off)
        # Only the length of f_off is used: it zero-pads the "on" kernel.
        kernel = np.concatenate([f_on,np.zeros(len(f_off))])[::-1]
        kernel  = kernel /sum(kernel )

        # Mirror-pad the mean trace on both sides before convolving, then cut
        # the central copy back out.
        trace_conv0 = np.convolve(np.concatenate([mean_ephys_v[::-1],mean_ephys_v,mean_ephys_v[::-1]]),kernel,mode = 'same')
        trace_conv0 = trace_conv0[len(mean_ephys_v):2*len(mean_ephys_v)]

        # Boxcar of one imaging frame (in ephys samples).
        kernel = np.ones(int(np.round(e_sr/framerate)))
        kernel  = kernel /sum(kernel )
        trace_conv = np.convolve(np.concatenate([trace_conv0[::-1],trace_conv0,trace_conv0[::-1]]),kernel,mode = 'same')
        trace_conv = trace_conv[len(mean_ephys_v):2*len(mean_ephys_v)]

        # Bin centres span the pooled AP-wave times of this movie, not just
        # the time vector of whichever AP happened to be last in the loop.
        bin_centers = np.arange(np.min(apwavetimes_conc),np.max(apwavetimes_conc),bin_step)

        bin_mean = list()
        for bin_center in bin_centers:
            bin_mean.append(np.mean(apwaves_conc[(apwavetimes_conc>bin_center-bin_size/2) & (apwavetimes_conc<bin_center+bin_size/2)]))
        ax_bin.plot(bin_centers*1000,np.asarray(bin_mean),'g-')
        ax_bin.invert_yaxis()
        ax_bin.set_xlim(np.asarray(AP_tlimits)*1000)
        ax_raw.invert_yaxis()
        ax_raw.autoscale(tight = True)
        ax_raw.set_xlim(np.asarray(AP_tlimits)*1000)
        ax_raw.set_ylabel('dF/F')
        ax_raw.set_title('subject: {} cell: {} movie: {} apnum: {}'.format(key['subject_id'],key['cell_number'],movie_number,sum(aps_now)))
        ax_ephys.set_xlim(np.asarray(AP_tlimits)*1000)
        ax_ephys.set_xlabel('ms')
        ax_ephys.set_ylabel('mV')
        ax_e_convolved.plot(ephys_t,mean_ephys_v,'k-',label = 'mean')
        ax_e_convolved.plot(ephys_t,trace_conv0,'g--',label = 'convolved mean')
        ax_e_convolved.plot(ephys_t,trace_conv,'g-',label = 'convolved & binned mean')
        ax_e_convolved.legend()
        ax_e_convolved.set_xlim(np.asarray(AP_tlimits)*1000)
        ax_e_convolved.set_xlabel('ms')
        # Save before showing: some backends release the figure once the
        # window opened by show() is closed.
        if save_image:
            fig.savefig('./figures/APwaveforms_subject_{}_cell_{}_movie_{}.png'.format(key['subject_id'],key['cell_number'],movie_number), bbox_inches = 'tight')
        plt.show()
        # NOTE(review): the result of this call is discarded - confirm whether
        # the instantiation is needed at all.
        imaging_gt.ROIEphysCorrelation()
예제 #4
0
     # Fragment: the enclosing function and loop headers are outside this view.
     # Window length used for the median estimates below
     # (presumably integration_window is in seconds and sr in Hz - TODO confirm).
     step_back = int(integration_window*sr)
     # Accept the square pulse only if the window fits before the pulse start,
     # the median stimulus in the first window is within max_baseline_current,
     # and the median response in that window is below max_v0.
     if step_back<squarepulse['square_pulse_start_idx'] and np.abs(np.median(stim[:step_back]))<=max_baseline_current and np.median(trace[:step_back])<max_v0:
         RS = float((ephysanal.SweepSeriesResistance()&squarepulse).fetch1('series_resistance'))
         RS_residual = float((ephysanal.SweepSeriesResistance()&squarepulse).fetch1('series_resistance_residual'))
         # Median response right before the pulse start and right before the pulse end.
         baseline_v = np.median(trace[squarepulse['square_pulse_start_idx']-step_back:squarepulse['square_pulse_start_idx']])
         Rin_v = np.median(trace[squarepulse['square_pulse_end_idx']-step_back:squarepulse['square_pulse_end_idx']])
         dv = Rin_v-baseline_v
         di = squarepulse['square_pulse_amplitude']
         # dv/di scaled by 1e-6 (presumably ohm -> megaohm; confirm units),
         # minus the residual series resistance.
         Rin = dv/di/1000000 - RS_residual
         Rins.append(Rin)
         RSs.append(RS)
         v0s.append(baseline_v*1000)
         #break
     
 # Summarise the cell only when more than one square pulse was accepted.
 if len(Rins)>1: 
     rs,threshold,baseline,hw,amplitude = (ephysanal.SweepSeriesResistance()*ephysanal.ActionPotential()*ephysanal.ActionPotentialDetails()&cell&'ap_real=1').fetch('series_resistance','ap_threshold','ap_baseline_value','ap_halfwidth','ap_amplitude')
     # Keep APs recorded with low enough series resistance and baseline.
     needed = (rs<AP_max_RS) & (baseline<AP_max_baseline)
     # At least 100 qualifying APs are required for the cell to be included.
     if sum(needed)>=100:
         # Sort by ascending half-width and use the first APs_needed APs.
         ap_order = np.argsort(hw[needed])#[::-1]
         AP_ampl = np.median(amplitude[needed][ap_order][:APs_needed])
         AP_hw = np.median(hw[needed][ap_order][:APs_needed])
         AP_threshold = np.median(threshold[needed][ap_order][:APs_needed])
         # Per-cell medians over the accepted square pulses.
         RS = np.median(RSs)
         Rin = np.median(Rins)
         v0 = np.median(v0s)
         ephys_data['cell_dict'].append(cell)
         ephys_data['virus'].append(virus)
         ephys_data['expression_time'].append(expression_time)
         ephys_data['RS'].append(RS)
         ephys_data['Rin'].append(Rin)
         ephys_data['AP_amplitude'].append(AP_ampl)
예제 #5
0
#         data['figure_handle'].savefig('./figures/{}_cell_{}_roi_type_{}_long.png'.format(key_cell['subject_id'],key_cell['cell_number'],key_cell['roi_type']), bbox_inches = 'tight')
#         print(cell)
# =============================================================================
#%%    #%%
# NOTE(review): key_cell is used here before the assignment further down in
# this snippet; this cell presumably relies on a value from an earlier run.
data = plot_ephys_ophys_trace(key_cell,time_to_plot=25,trace_window = 1,show_e_ap_peaks = True,show_o_ap_peaks = True)
#%%
# Select one cell / ROI type to plot.
session = 1
subject_id = 456462
cell_number = 5
roi_type = 'Spikepursuit'#'Spikepursuit'#'VolPy_denoised'#'SpikePursuit'#'VolPy_dexpF0'#'VolPy'#'SpikePursuit_dexpF0'#'VolPy_dexpF0'#''Spikepursuit'#'VolPy'#
key_cell = {'session':session,'subject_id':subject_id,'cell_number':cell_number,'roi_type':roi_type }

session_time, cell_recording_start = (experiment.Session()*ephys_patch.Cell()&key_cell).fetch1('session_time','cell_recording_start')
# Earliest movie start among the movies that have a ground-truth ROI for this cell.
first_movie_start_time =  np.min(np.asarray(((imaging.Movie()*imaging_gt.GroundTruthROI())&key_cell).fetch('movie_start_time'),float))
# Shift by the session time (in seconds) so it can be compared with
# cell_recording_start below.
first_movie_start_time_real = first_movie_start_time + session_time.total_seconds()
# Threshold and peak time of every real AP (ap_real=1) of the cell.
threshold,apmaxtime = (imaging_gt.ROIAPWave()*ephysanal.ActionPotential()*ephysanal.ActionPotentialDetails()&key_cell&'ap_real=1').fetch('ap_threshold','ap_max_time')
threshold=np.asarray(threshold,float)
apmaxtime=np.asarray(apmaxtime,float)

# =============================================================================
# session_time_to_plot = time_to_plot+first_movie_start_time  # time relative to session start
# cell_time_to_plot= session_time_to_plot + session_time.total_seconds() -cell_recording_start.total_seconds() # time relative to recording start
# =============================================================================

#%
# Centre the plot on the AP with the lowest threshold, expressed relative to
# the start of the first movie.
time_to_plot = apmaxtime[np.argmin(threshold)]+cell_recording_start.total_seconds() - first_movie_start_time_real
data = plot_ephys_ophys_trace(key_cell,
                              time_to_plot=time_to_plot,
                              trace_window = .5,
                              show_stimulus = False,
                              show_e_ap_peaks = False,
예제 #6
0
 # Fragment: the enclosing function header is outside this view.
 # Stacked axes: response trace, stimulus, AP amplitude (with a twin y-axis
 # for the threshold) and S/N ratio.
 fig=plt.figure()
 ax_ephys = fig.add_axes([0,0,2,.8])
 ax_stim = fig.add_axes([0,-.3,2,.2])
 ax_ap1 = fig.add_axes([0,-.8,2,.4])
 ax_ap2 = ax_ap1.twinx()
 ax_snr = fig.add_axes([0,-1.3,2,.4])
 
 # One iteration per sweep: plot the traces, then the properties of its APs.
 for t,response,stimulus,metadata_now in zip(sweep_time,sweep_response,sweep_stimulus,sweep_metadata):
     ax_ephys.plot(t,response,'k-')
     ax_stim.plot(t,stimulus,'k-')
     #%
     key_cell ={'subject_id': metadata_now['subject_id'],
                'session': metadata_now['session'],
                'cell_number':metadata_now['cell_number'],
                'sweep_number':metadata_now['sweep_number']}
     # Real APs (ap_real=1) of this sweep for the "VolPy" ROI type.
     ap_max_time, ap_amplitude,ap_halfwidth,ap_threshold,snratio = (imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()*ephysanal.ActionPotential()*ephysanal.ActionPotentialDetails()&key_cell&'ap_real=1'&'roi_type="VolPy"').fetch('ap_max_time','ap_amplitude','ap_halfwidth','ap_threshold','apwave_snratio')
     ap_max_time = np.asarray(ap_max_time,float)
     # Threshold is plotted after subtracting junction_potential.
     ax_ap2.plot(ap_max_time,ap_threshold-junction_potential,'ro')
     ax_ap1.plot(ap_max_time,ap_amplitude,'ko')
     ax_snr.plot(ap_max_time,snratio,'go')
     
     
     #%
 # Optional imaging panel above the ephys axes.
 if dff is not None:
     ax_ophys = fig.add_axes([0,1,2,.8])
     prevminval = 0
     # Each trace is drawn with a smaller alpha than the previous one and is
     # shifted by the running minimum so the traces stack downwards.
     for dff_now,alpha_now in zip(dff_list,np.arange(1,1/(len(dff_list)+1),-1/(len(dff_list)+1))):
         dfftoplotnow = dff_now + prevminval
         ax_ophys.plot(frame_times,dfftoplotnow,'g-',alpha=alpha_now)
         # Offset for the next trace: just below the lowest point drawn so far.
         prevminval = np.min(dfftoplotnow) -.01
     #ax_ophys.plot(frame_times,dff,'g-')
예제 #7
0
def _sn_plot_cell_label(cell_now):
    """Return the legend label 'Subject:<id> Cell:<n> - <virus>' for one cell row."""
    virus_id = (lab.Surgery.VirusInjection()
                & 'subject_id = {}'.format(
                    cell_now['subject_id'])).fetch('virus_id')[0]
    if virus_id == 238:
        virus = 'Voltron 1'
    elif virus_id == 240:
        virus = 'Voltron 2'
    else:
        virus = '??'
    return 'Subject:{}'.format(
        cell_now['subject_id']) + ' Cell:{} - {}'.format(
            cell_now['cell_number'], virus)


def plot_cell_SN_ratio_APwise(roi_type='VolPy',
                              v0_max=-35,
                              holding_min=-600,
                              frame_rate_min=300,
                              frame_rate_max=1800,
                              F0_min=50,
                              bin_num=10):
    """Plot AP-wise imaging statistics per cell against expression time and F0.

    Three figures are produced for the cells that have ground-truth ROIs of
    ``roi_type``:

    1. mean +/- SD of F0, AP peak amplitude, noise and S/N ratio of the first
       50 APs of each cell against expression time (days between the
       surgeries of the subject);
    2. raw and F0-binned scatter plots of S/N ratio, noise and peak amplitude
       against F0;
    3. a histogram of the mean S/N ratio of the first 50 APs of the cells
       that pass the ``v0_max`` and ``holding_min`` filters.

    The arguments are used as passed; they are not overwritten inside the
    function.

    Parameters
    ----------
    roi_type : str
        ``roi_type`` value used to restrict ``imaging_gt.GroundTruthROI``;
        also used as the histogram title.
    v0_max : float
        Only cells whose median response trace (x1000, i.e. mV) is below this
        value enter the histogram.
    holding_min : float
        Only cells whose median stimulus (x1e12, i.e. pA) is above this value
        enter the histogram.
    frame_rate_min, frame_rate_max : float
        Exclusive bounds on ``movie_frame_rate`` for the APs that are used.
    F0_min : float
        Only APs with ``apwave_f0`` above this value are used.
    bin_num : int
        Controls the F0 bin width: the F0 range of each cell is divided into
        ``bin_num + 1`` equal steps.

    Returns
    -------
    None
    """
    #%% Show S/N ratios for each AP

    # plt.get_cmap is available in every matplotlib release, unlike cm.get_cmap.
    cmap = plt.get_cmap('jet')

    key = {'roi_type': roi_type}
    gtdata = pd.DataFrame((imaging_gt.GroundTruthROI() & key))
    cells = gtdata.groupby([
        'session', 'subject_id', 'cell_number', 'motion_correction_method',
        'roi_type'
    ]).size().reset_index(name='Freq')
    snratio = list()
    v0s = list()
    holdings = list()
    rss = list()
    threshs = list()
    mintreshs = list()
    f0s_all = list()
    snratios_all = list()
    peakamplitudes_all = list()
    noise_all = list()
    for _, cell in cells.iterrows():
        key_cell = dict(cell)
        # 'Freq' is the groupby counter, not a database attribute.
        del key_cell['Freq']
        snratios, f0, peakamplitudes, noises = (
            imaging.Movie() * imaging_gt.GroundTruthROI() *
            imaging_gt.ROIAPWave() * ephysanal.ActionPotentialDetails()
            & key_cell & 'ap_real = 1'
            & 'movie_frame_rate > {}'.format(frame_rate_min)
            & 'movie_frame_rate < {}'.format(frame_rate_max)
            & 'apwave_f0 > {}'.format(F0_min)).fetch('apwave_snratio',
                                                     'apwave_f0',
                                                     'apwave_peak_amplitude',
                                                     'apwave_noise')

        # The first sweep that has an AP wave is used to characterise the cell.
        sweep = (imaging_gt.GroundTruthROI() * imaging_gt.ROIAPWave()
                 & key_cell).fetch('sweep_number')[0]
        sweep_restriction = 'sweep_number = {}'.format(sweep)
        thresh = (imaging_gt.GroundTruthROI() * imaging_gt.ROIAPWave() *
                  ephysanal.ActionPotentialDetails() & key_cell
                  & 'ap_real = 1').fetch('ap_threshold')
        trace = (ephys_patch.SweepResponse() * imaging_gt.GroundTruthROI()
                 & key_cell
                 & sweep_restriction).fetch('response_trace')[0]
        stimulus = (ephys_patch.SweepStimulus() * imaging_gt.GroundTruthROI()
                    & key_cell
                    & sweep_restriction).fetch('stimulus_trace')[0]
        RS = (ephysanal.SweepSeriesResistance() * imaging_gt.GroundTruthROI()
              & key_cell
              & sweep_restriction).fetch('series_resistance')[0]

        medianvoltage = np.median(trace) * 1000
        holding = np.median(stimulus) * 10**12
        snratio.append(np.mean(snratios[:50]))
        v0s.append(medianvoltage)
        holdings.append(holding)
        rss.append(RS)
        threshs.append(thresh)
        mintreshs.append(np.min(thresh))
        f0s_all.append(f0)
        snratios_all.append(snratios)
        peakamplitudes_all.append(peakamplitudes)
        noise_all.append(noises)

    #%%
    # Number of leading APs summarised per cell in the expression-time figure.
    apnum = 50
    fig = plt.figure(figsize=[10, 10])
    ax_exptime_f0 = fig.add_subplot(221)
    ax_exptime_dff = fig.add_subplot(222)
    ax_exptime_noise = fig.add_subplot(223)
    ax_exptime_snration = fig.add_subplot(224)
    for loopidx, (f0, snratio_now, noise_now, peakampl_now,
                  cell_now) in enumerate(
                      zip(f0s_all, snratios_all, noise_all, peakamplitudes_all,
                          cells.iterrows())):
        if len(f0) > 0:
            color_now = cmap(loopidx / len(cells))
            cell_now = cell_now[1]
            label_now = _sn_plot_cell_label(cell_now)

            # Days between the first two surgery start times of the subject.
            expression_time = np.diff(
                (lab.Surgery() & 'subject_id = {}'.format(
                    cell_now['subject_id'])).fetch('start_time'))[0].days
            for ax_now, values in ((ax_exptime_f0, f0),
                                   (ax_exptime_dff, peakampl_now),
                                   (ax_exptime_noise, noise_now),
                                   (ax_exptime_snration, snratio_now)):
                first_aps = values[:apnum]
                ax_now.plot(expression_time,
                            np.mean(first_aps),
                            'o',
                            ms=10,
                            color=color_now,
                            label=label_now)
                ax_now.errorbar(expression_time,
                                np.mean(first_aps),
                                np.std(first_aps),
                                ecolor=color_now)

    ax_exptime_f0.set_xlabel('expression time (days)')
    ax_exptime_f0.set_ylabel('F0 (pixel intensity)')
    ax_exptime_dff.set_xlabel('expression time (days)')
    ax_exptime_dff.set_ylabel('AP peak amplitude (dF/F)')
    ax_exptime_noise.set_xlabel('expression time (days)')
    ax_exptime_noise.set_ylabel('noise (dF/F)')
    ax_exptime_snration.set_xlabel('expression time (days)')
    ax_exptime_snration.set_ylabel('S/N ratio')

    #%%  for each AP

    fig = plt.figure()
    ax_sn_f0 = fig.add_axes([0, 0, 1, 1])
    ax_sn_f0_binned = fig.add_axes([0, -1.2, 1, 1])
    ax_noise_f0 = fig.add_axes([2.6, 0, 1, 1])
    ax_noise_f0_binned = fig.add_axes([2.6, -1.2, 1, 1])
    ax_peakampl_f0 = fig.add_axes([1.3, 0, 1, 1])
    ax_peakampl_f0_binned = fig.add_axes([1.3, -1.2, 1, 1])
    for loopidx, (f0, snratio_now, noise_now, peakampl_now,
                  cell_now) in enumerate(
                      zip(f0s_all, snratios_all, noise_all, peakamplitudes_all,
                          cells.iterrows())):
        if len(f0) > 0:
            color_now = cmap(loopidx / len(cells))
            cell_now = cell_now[1]
            label_now = _sn_plot_cell_label(cell_now)
            ax_sn_f0.plot(f0,
                          snratio_now,
                          'o',
                          ms=1,
                          color=color_now,
                          label=label_now)
            ax_noise_f0.plot(f0,
                             noise_now,
                             'o',
                             ms=1,
                             color=color_now,
                             label=label_now)
            ax_peakampl_f0.plot(f0,
                                peakampl_now,
                                'o',
                                ms=1,
                                color=color_now,
                                label=label_now)

            f0_low = np.min(f0)
            f0_high = np.max(f0)
            bin_width = (f0_high - f0_low) / (bin_num + 1)
            if bin_width > 0:
                lows = np.arange(f0_low, f0_high, bin_width)
                highs = lows + bin_width
            else:
                # All F0 values are identical (e.g. a single AP): np.arange
                # cannot be called with a zero step, so there is nothing to bin.
                lows = highs = []
            mean_f0 = list()
            sd_f0 = list()
            mean_sn = list()
            sd_sn = list()
            mean_noise = list()
            sd_noise = list()
            mean_ampl = list()
            sd_ampl = list()
            for low, high in zip(lows, highs):
                idx = (f0 >= low) & (f0 < high)
                # Count the APs inside the bin; len(idx) would be the length
                # of the whole mask and never reflect the bin occupancy.
                if np.count_nonzero(idx) > 10:
                    mean_f0.append(np.mean(f0[idx]))
                    sd_f0.append(np.std(f0[idx]))
                    mean_sn.append(np.mean(snratio_now[idx]))
                    sd_sn.append(np.std(snratio_now[idx]))
                    mean_noise.append(np.mean(noise_now[idx]))
                    sd_noise.append(np.std(noise_now[idx]))
                    mean_ampl.append(np.mean(peakampl_now[idx]))
                    sd_ampl.append(np.std(peakampl_now[idx]))

            ax_sn_f0_binned.errorbar(mean_f0,
                                     mean_sn,
                                     sd_sn,
                                     sd_f0,
                                     'o-',
                                     color=color_now,
                                     label=label_now)
            ax_noise_f0_binned.errorbar(mean_f0,
                                        mean_noise,
                                        sd_noise,
                                        sd_f0,
                                        'o-',
                                        color=color_now,
                                        label=label_now)
            ax_peakampl_f0_binned.errorbar(mean_f0,
                                           mean_ampl,
                                           sd_ampl,
                                           sd_f0,
                                           'o-',
                                           color=color_now,
                                           label=label_now)

    ax_sn_f0.set_xlabel('F0')
    ax_sn_f0.set_ylabel('S/N ratio')
    ax_sn_f0_binned.set_xlabel('F0')
    ax_sn_f0_binned.set_ylabel('S/N ratio')

    # The legend is anchored outside the binned S/N axes.
    ax_sn_f0_binned.legend(loc='upper center',
                           bbox_to_anchor=(-.45, 1.5),
                           shadow=True,
                           ncol=1)

    ax_noise_f0.set_xlabel('F0')
    ax_noise_f0.set_ylabel('Noise (std(dF/F))')
    ax_noise_f0_binned.set_xlabel('F0')
    ax_noise_f0_binned.set_ylabel('Noise (std(dF/F))')

    ax_peakampl_f0.set_xlabel('F0')
    ax_peakampl_f0.set_ylabel('Peak amplitude (dF/F)')
    ax_peakampl_f0_binned.set_xlabel('F0')
    ax_peakampl_f0_binned.set_ylabel('Peak amplitude (dF/F)')

    #%%
    cells['SN'] = snratio
    cells['V0'] = v0s
    cells['holding'] = holdings
    cells['RS'] = np.asarray(rss, float)
    print(cells)
    cells = cells[cells['V0'] < v0_max]
    cells = cells[cells['holding'] > holding_min]
    print(cells)
    #% S/N ratio histogram
    fig = plt.figure()
    ax_hist = fig.add_axes([0, 0, 1, 1])
    ax_hist.hist(cells['SN'].values)
    ax_hist.set_xlabel('S/N ratio of first 50 spikes')
    ax_hist.set_ylabel('# of cells')
    ax_hist.set_title(roi_type.replace('_', ' '))
    ax_hist.set_xlim([0, 15])