def extrapolate_sweep_series_resistance():
    """Fill in missing series-resistance values in ``ephysanal.SweepSeriesResistance``.

    For every subject/cell that has at least one row whose
    ``series_resistance`` is NULL, all rows of that cell are loaded and every
    contiguous run of missing rows is filled:

    * a run at the very beginning copies the three resistance columns from
      the first row after the run;
    * a run at the very end copies them from the last row before the run;
    * a run in the middle gets ``series_resistance`` values stepping linearly
      from the preceding row towards the following row, keeps the preceding
      row's ``series_resistance_bridged`` and gets
      ``series_resistance_residual`` recomputed as the difference of the two.

    Every originally-missing row is then deleted from the table and
    re-inserted with the filled values.  A cell without a single measured
    value cannot be filled and is skipped.

    Side effects: rows are deleted from / inserted into the table, and
    ``dj.config['safemode']`` is ``True`` when the function has finished
    with a cell (it is switched off only around each delete).

    NOTE(review): rows are assumed to come back from the table in sweep
    order, so that neighbouring rows are neighbouring sweeps - confirm.
    """
    rs_columns = ('series_resistance',
                  'series_resistance_bridged',
                  'series_resistance_residual')
    subject_ids = np.unique((ephysanal.SweepSeriesResistance() & 'series_resistance is null').fetch('subject_id'))
    for subject_id in subject_ids:
        print('extrapolating sweep series resistance for subject {}'.format(subject_id))
        cell_numbers = np.unique((ephysanal.SweepSeriesResistance()
                                  & 'series_resistance is null'
                                  & 'subject_id = {}'.format(subject_id)).fetch('cell_number'))
        for cell_number in cell_numbers:
            key = {'subject_id': subject_id,
                   'cell_number': cell_number}
            data_old = pd.DataFrame(ephysanal.SweepSeriesResistance() & key)
            data = data_old.copy()
            # isna() catches both None (object columns) and NaN (float columns)
            nones = data['series_resistance'].isna().values
            if nones.all():
                # nothing to copy or interpolate from
                print('no measured series resistance for cell {} - skipping'.format(cell_number))
                continue
            labels, nums = scipy.ndimage.label(nones)
            last_row = len(data) - 1
            for num in range(1, nums + 1):
                idxs_now = np.where(labels == num)[0]
                first, last = idxs_now[0], idxs_now[-1]
                # .loc writes into `data` itself; chained indexing
                # (data[col][idx] = ...) may silently write into a copy.
                if first == 0:
                    for column in rs_columns:
                        data.loc[idxs_now, column] = data.loc[last + 1, column]
                elif last == last_row:
                    for column in rs_columns:
                        data.loc[idxs_now, column] = data.loc[first - 1, column]
                else:
                    rs_start = float(data.loc[first - 1, 'series_resistance'])
                    rs_end = float(data.loc[last + 1, 'series_resistance'])
                    # Same values as arange(start, end, (end-start)/(n+1))[:-1]
                    # but with an exact element count and no failure when
                    # rs_start == rs_end (zero step).
                    # NOTE(review): the first filled value equals rs_start
                    # (steps 0..n-1 are kept, not 1..n) - confirm this is the
                    # intended interpolation.
                    rss = np.linspace(rs_start, rs_end, len(idxs_now) + 1, endpoint=False)[:-1]
                    data.loc[idxs_now, 'series_resistance'] = rss
                    data.loc[idxs_now, 'series_resistance_bridged'] = data.loc[first - 1, 'series_resistance_bridged']
                    bridged_now = data.loc[idxs_now, 'series_resistance_bridged'].astype(float).values
                    data.loc[idxs_now, 'series_resistance_residual'] = rss - bridged_now

            # replace every originally-missing row with its filled version
            for was_missing, (_, row_old), (_, row_new) in zip(nones, data_old.iterrows(), data.iterrows()):
                if not was_missing:
                    continue
                row_old = dict(row_old)
                row_new = dict(row_new)
                # restrict the delete by the remaining (non-resistance) columns
                for column in rs_columns:
                    del row_old[column]
                dj.config['safemode'] = False
                try:
                    (ephysanal.SweepSeriesResistance() & row_old).delete()
                finally:
                    # never leave the confirmation prompt disabled
                    dj.config['safemode'] = True
                ephysanal.SweepSeriesResistance().insert1(row_new, allow_direct_insert=True)
            print(cell_number)
Beispiel #2
0
def populatemytables_core_paralel(arguments, runround):
    """Populate the ``ephysanal`` tables that belong to one processing round.

    ``arguments`` is forwarded as keyword arguments to every ``populate()``
    call.  ``runround`` selects which group of tables is populated; a value
    that matches none of the rounds below does nothing.
    """
    # (round id, tables populated in that round, in this order)
    rounds = (
        (1, ('SquarePulse', 'SweepFrameTimes')),
        (2, ('SquarePulseSeriesResistance', 'SweepSeriesResistance')),
        (3, ('SweepResponseCorrected',)),
        (4, ('ActionPotential', 'ActionPotentialDetails')),
    )
    for round_id, table_names in rounds:
        if runround == round_id:
            for table_name in table_names:
                getattr(ephysanal, table_name)().populate(**arguments)
            break
Beispiel #3
0
def get_sweep(key_sweep,junction_potential = 13.5, downsampled_rate = 10000):
    """Return one sweep's corrected, downsampled voltage, current and time traces.

    Parameters
    ----------
    key_sweep : dict
        Restriction identifying the sweep.  It is modified in place: the
        ``cell_number`` looked up in ``imaging_gt.CellMovieCorrespondance`` is
        added to it before the ephys tables are queried.
    junction_potential : float
        Value subtracted from the voltage after it has been multiplied by
        1000 (the original inline notes label it as mV).
    downsampled_rate : float
        Target sampling rate of the returned traces (labelled Hz in the
        original inline notes).

    Returns
    -------
    v_out, i_out, t_out
        Downsampled corrected voltage, stimulus multiplied by 10**12, and
        time (sample index / sample rate + sweep start time).

    The stimulus is convolved with a normalised exponential kernel
    (time constant 0.1 ms) and the result, scaled by the residual series
    resistance, is subtracted from the response.  If the residual series
    resistance cannot be fetched or converted to float, 0 is used, i.e. no
    series-resistance correction is applied.
    """
    key_sweep['cell_number'] = (imaging_gt.CellMovieCorrespondance()&key_sweep).fetch1('cell_number')

    e_sr = (ephys_patch.SweepMetadata()&key_sweep).fetch1('sample_rate')
    try:
        uncompensatedRS = float((ephysanal.SweepSeriesResistance()&key_sweep).fetch1('series_resistance_residual'))
    except Exception:
        # best effort: missing or NULL entry -> no correction.  `Exception`
        # (not a bare except) so Ctrl-C / SystemExit still propagate.
        uncompensatedRS = 0
    v = (ephys_patch.SweepResponse()&key_sweep).fetch1('response_trace')
    i = (ephys_patch.SweepStimulus()&key_sweep).fetch1('stimulus_trace')

    # exponential kernel: rising exponential, zero-padded to twice its length
    # and then reversed, normalised to unit sum
    tau_1_on = .1/1000
    t = np.arange(0, .001, 1/e_sr)
    f_on = np.exp(t/tau_1_on)
    f_on = f_on/np.max(f_on)
    kernel = np.concatenate([f_on, np.zeros(len(t))])[::-1]
    kernel = kernel/kernel.sum()
    i_conv = np.convolve(i, kernel, 'same')
    # NOTE(review): the 10**6 factor presumably converts the stored residual
    # resistance from MOhm to Ohm - confirm against the table definition.
    v_comp = (v - i_conv*uncompensatedRS*10**6)*1000 - junction_potential
    i = i * 10**12

    # fetch1 returns the scalar directly; float() on the 1-element array
    # returned by fetch() is deprecated in recent NumPy versions.
    sweep_start_time = float((ephys_patch.Sweep()&key_sweep).fetch1('sweep_start_time'))
    trace_t = np.arange(len(v))/e_sr + sweep_start_time

    # never let the factor round down to 0 (a slice step of 0 is invalid)
    downsample_factor = max(1, int(np.round(e_sr/downsampled_rate)))
    offset = int(downsample_factor/2)
    # smooth with a moving average, then keep every downsample_factor-th sample
    v_out = moving_average(v_comp, n=downsample_factor)[offset::downsample_factor]
    i_out = moving_average(i, n=downsample_factor)[offset::downsample_factor]
    t_out = moving_average(trace_t, n=downsample_factor)[offset::downsample_factor]

    return v_out, i_out, t_out
Beispiel #4
0
def plot_cell_SN_ratio_APwise(roi_type = 'VolPy',v0_max = -35,holding_min = -600,frame_rate_min =300, frame_rate_max = 800 ,bin_num = 10 ):
    """Plot per-action-potential S/N ratio, noise and peak amplitude against F0.

    Parameters
    ----------
    roi_type : str
        ``roi_type`` used to select rows of ``imaging_gt.GroundTruthROI``.
    v0_max : float
        Cells whose median response (x1000; labelled mV in the original
        notes) is not below this value are left out of the histogram.
    holding_min : float
        Cells whose median stimulus (x10**12; labelled pA in the original
        notes) is not above this value are left out of the histogram.
    frame_rate_min, frame_rate_max : float
        Only APs from movies with ``movie_frame_rate`` strictly between these
        limits are used.
    bin_num : int
        The F0 range of each cell is split into ``bin_num + 1`` equal bins
        for the binned plots.

    Two figures are created: one with per-AP scatter plots and binned
    mean +/- SD curves (S/N, peak amplitude, noise vs. F0) for every cell,
    and one histogram of the mean S/N ratio of the first 50 APs of the cells
    passing the ``v0_max`` / ``holding_min`` filters.  The per-cell table is
    printed before and after filtering.  Nothing is returned.
    """
    cmap = cm.get_cmap('jet')
    key = {'roi_type':roi_type}
    gtdata = pd.DataFrame((imaging_gt.GroundTruthROI()&key))
    cells = gtdata.groupby(['session', 'subject_id','cell_number','motion_correction_method','roi_type']).size().reset_index(name='Freq')
    snratio = list()
    v0s = list()
    holdings = list()
    rss = list()
    f0s_all = list()
    snratios_all = list()
    peakamplitudes_all = list()
    noise_all = list()
    for _, cell in cells.iterrows():
        key_cell = dict(cell)
        del key_cell['Freq']
        snratios,f0,peakamplitudes,noises = (imaging.Movie()*imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()*ephysanal.ActionPotentialDetails()&key_cell&'ap_real = 1'&'movie_frame_rate > {}'.format(frame_rate_min)&'movie_frame_rate < {}'.format(frame_rate_max)).fetch('apwave_snratio','apwave_f0','apwave_peak_amplitude','apwave_noise')

        # the first sweep listed for the cell is used for the V0 / holding / RS values
        sweep = (imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()&key_cell).fetch('sweep_number')[0]
        sweep_restriction = 'sweep_number = {}'.format(sweep)
        trace = (ephys_patch.SweepResponse()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('response_trace')[0]
        stimulus = (ephys_patch.SweepStimulus()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('stimulus_trace')[0]
        RS = (ephysanal.SweepSeriesResistance()*imaging_gt.GroundTruthROI()&key_cell&sweep_restriction).fetch('series_resistance')[0]

        snratio.append(np.mean(snratios[:50]))
        v0s.append(np.median(trace)*1000)
        holdings.append(np.median(stimulus)*10**12)
        rss.append(RS)
        f0s_all.append(f0)
        snratios_all.append(snratios)
        peakamplitudes_all.append(peakamplitudes)
        noise_all.append(noises)

    #%%  for each AP
    fig=plt.figure()
    ax_sn_f0 = fig.add_axes([0,0,1,1])
    ax_sn_f0_binned = fig.add_axes([0,-1.2,1,1])
    ax_noise_f0 = fig.add_axes([2.6,0,1,1])
    ax_noise_f0_binned = fig.add_axes([2.6,-1.2,1,1])
    ax_peakampl_f0 = fig.add_axes([1.3,0,1,1])
    ax_peakampl_f0_binned = fig.add_axes([1.3,-1.2,1,1])
    for loopidx, (f0,snratio_now,noise_now,peakampl_now,cell_now) in enumerate(zip(f0s_all,snratios_all,noise_all,peakamplitudes_all,cells.iterrows())):
        if len(f0)>0:
            color_now = cmap(loopidx/len(cells))
            cell_now = cell_now[1]
            label_now = 'Subject:{}'.format(cell_now['subject_id'])+' Cell:{}'.format(cell_now['cell_number'])
            ax_sn_f0.plot(f0,snratio_now,'o',ms=1, color = color_now, label= label_now)
            ax_noise_f0.plot(f0,noise_now,'o',ms=1, color = color_now, label= label_now)
            ax_peakampl_f0.plot(f0,peakampl_now,'o',ms=1, color = color_now, label= label_now)
            # bin_num + 1 equal-width bins over the F0 range; linspace also
            # copes with all F0 values being identical (zero-width range)
            bin_edges = np.linspace(np.min(f0),np.max(f0),bin_num+2)
            mean_f0 = list()
            sd_f0 = list()
            mean_sn = list()
            sd_sn =list()
            mean_noise = list()
            sd_noise =list()
            mean_ampl = list()
            sd_ampl =list()
            for low,high in zip(bin_edges[:-1],bin_edges[1:]):
                idx = (f0 >= low) & (f0 < high)
                # require more than 10 APs *in this bin*; len(idx) would be the
                # total number of APs of the cell regardless of the bin
                if np.sum(idx)>10:
                    mean_f0.append(np.mean(f0[idx]))
                    sd_f0.append(np.std(f0[idx]))
                    mean_sn.append(np.mean(snratio_now[idx]))
                    sd_sn.append(np.std(snratio_now[idx]))
                    mean_noise.append(np.mean(noise_now[idx]))
                    sd_noise.append(np.std(noise_now[idx]))
                    mean_ampl.append(np.mean(peakampl_now[idx]))
                    sd_ampl.append(np.std(peakampl_now[idx]))

            ax_sn_f0_binned.errorbar(mean_f0,mean_sn,sd_sn,sd_f0,'o-', color = color_now, label= label_now)
            ax_noise_f0_binned.errorbar(mean_f0,mean_noise,sd_noise,sd_f0,'o-', color = color_now, label= label_now)
            ax_peakampl_f0_binned.errorbar(mean_f0,mean_ampl,sd_ampl,sd_f0,'o-', color = color_now, label= label_now)

    ax_sn_f0.set_xlabel('F0')
    ax_sn_f0.set_ylabel('S/N ratio')
    ax_sn_f0_binned.set_xlabel('F0')
    ax_sn_f0_binned.set_ylabel('S/N ratio')

    ax_sn_f0_binned.legend(loc='upper center', bbox_to_anchor=(-.45, 1.5), shadow=True, ncol=1)

    ax_noise_f0.set_xlabel('F0')
    ax_noise_f0.set_ylabel('Noise (std(dF/F))')
    ax_noise_f0_binned.set_xlabel('F0')
    ax_noise_f0_binned.set_ylabel('Noise (std(dF/F))')

    ax_peakampl_f0.set_xlabel('F0')
    ax_peakampl_f0.set_ylabel('Peak amplitude (dF/F)')
    ax_peakampl_f0_binned.set_xlabel('F0')
    ax_peakampl_f0_binned.set_ylabel('Peak amplitude (dF/F)')

    #%% per-cell summary table and S/N ratio histogram
    cells['SN']=snratio
    cells['V0']=v0s
    cells['holding']=holdings
    cells['RS']=np.asarray(rss,float)
    print(cells)
    cells = cells[cells['V0']<v0_max]
    cells = cells[cells['holding']>holding_min]
    print(cells)
    fig=plt.figure()
    ax_hist = fig.add_axes([0,0,1,1])
    ax_hist.hist(cells['SN'].values)
    ax_hist.set_xlabel('S/N ratio of first 50 spikes')
    ax_hist.set_ylabel('# of cells')
    ax_hist.set_title(roi_type.replace('_',' '))
    ax_hist.set_xlim([0,15])
Beispiel #5
0
    elif virus_id == 240:
        virus = 'Voltron 2'
    else:
        virus = '??'

    squarepulses = ephysanal.SquarePulse()&cell&'square_pulse_amplitude < {}'.format(min_current)&'square_pulse_length > {}'.format(min_pulse_time)
    Rins = list()
    RSs = list()
    v0s = list()
    for squarepulse in squarepulses:
        trace = (ephys_patch.SweepResponse()&squarepulse).fetch1('response_trace')
        stim = (ephys_patch.SweepStimulus()&squarepulse).fetch1('stimulus_trace')
        sr = (ephys_patch.SweepMetadata()&squarepulse).fetch1('sample_rate')
        step_back = int(integration_window*sr)
        if step_back<squarepulse['square_pulse_start_idx'] and np.abs(np.median(stim[:step_back]))<=max_baseline_current and np.median(trace[:step_back])<max_v0:
            RS = float((ephysanal.SweepSeriesResistance()&squarepulse).fetch1('series_resistance'))
            RS_residual = float((ephysanal.SweepSeriesResistance()&squarepulse).fetch1('series_resistance_residual'))
            baseline_v = np.median(trace[squarepulse['square_pulse_start_idx']-step_back:squarepulse['square_pulse_start_idx']])
            Rin_v = np.median(trace[squarepulse['square_pulse_end_idx']-step_back:squarepulse['square_pulse_end_idx']])
            dv = Rin_v-baseline_v
            di = squarepulse['square_pulse_amplitude']
            Rin = dv/di/1000000 - RS_residual
            Rins.append(Rin)
            RSs.append(RS)
            v0s.append(baseline_v*1000)
            #break
        
    if len(Rins)>1: 
        rs,threshold,baseline,hw,amplitude = (ephysanal.SweepSeriesResistance()*ephysanal.ActionPotential()*ephysanal.ActionPotentialDetails()&cell&'ap_real=1').fetch('series_resistance','ap_threshold','ap_baseline_value','ap_halfwidth','ap_amplitude')
        needed = (rs<AP_max_RS) & (baseline<AP_max_baseline)
        if sum(needed)>=100:
Beispiel #6
0
         motion_corr_vectors = np.asarray((imaging.MotionCorrection()*imaging.RegisteredMovie()&movie&'motion_correction_method  = "VolPy"'&'motion_corr_description= "rigid motion correction done with VolPy"').fetch1('motion_corr_vectors'))
         movie_dict['movie_start_time'] = frame_times[0]        
         movie_files = list()
         repositories , directories , fnames = (imaging.MovieFile() & movie).fetch('movie_file_repository','movie_file_directory','movie_file_name')
         for repository,directory,fname in zip(repositories,directories,fnames):
             movie_files.append(os.path.join(dj.config['locations.{}'.format(repository)],directory,fname))
         sweepdata_out = list()
         sweepmetadata_out = list()
         for sweep_number in sweep_numbers:
             #%
             key_sweep = key.copy()
             key_sweep['sweep_number'] = sweep_number
 
             neutralizationenable,e_sr= (ephys_patch.SweepMetadata()&key_sweep).fetch1('neutralizationenable','sample_rate')
             try:
                 uncompensatedRS =  float((ephysanal.SweepSeriesResistance()&key_sweep).fetch1('series_resistance_residual'))
             except:
                 uncompensatedRS = 0
             v = (ephys_patch.SweepResponse()&key_sweep).fetch1('response_trace')
             i = (ephys_patch.SweepStimulus()&key_sweep).fetch1('stimulus_trace')
             tau_1_on =.1/1000
             t = np.arange(0,.001,1/e_sr)
             f_on = np.exp(t/tau_1_on) 
             f_on = f_on/np.max(f_on)
             kernel = np.concatenate([f_on,np.zeros(len(t))])[::-1]
             kernel  = kernel /sum(kernel )  
             i_conv = np.convolve(i,kernel,'same')
             v_comp = (v - i*uncompensatedRS*10**6)*1000 - junction_potential
             i = i * 10**12
             
             sweep_start_time  = float((ephys_patch.Sweep()&key_sweep).fetch('sweep_start_time')) 
Beispiel #7
0
    key_cell = dict(cell)
    del key_cell['Freq']
    snratios = (imaging_gt.GroundTruthROI() * imaging_gt.ROIAPWave()
                & key_cell).fetch('apwave_snratio')
    sweep = (imaging_gt.GroundTruthROI() * imaging_gt.ROIAPWave()
             & key_cell).fetch('sweep_number')[0]
    trace = (ephys_patch.SweepResponse() * imaging_gt.GroundTruthROI()
             & key_cell
             & 'sweep_number = {}'.format(sweep)).fetch('response_trace')
    trace = trace[0]
    stimulus = (ephys_patch.SweepStimulus() * imaging_gt.GroundTruthROI()
                & key_cell
                & 'sweep_number = {}'.format(sweep)).fetch('stimulus_trace')
    stimulus = stimulus[0]

    RS = (ephysanal.SweepSeriesResistance() * imaging_gt.GroundTruthROI()
          & key_cell
          & 'sweep_number = {}'.format(sweep)).fetch('series_resistance')
    RS = RS[0]

    medianvoltage = np.median(trace) * 1000
    holding = np.median(stimulus) * 10**12
    #print(np.mean(snratios[:100]))
    snratio.append(np.mean(snratios[:50]))
    v0s.append(medianvoltage)
    holdings.append(holding)
    rss.append(RS)
#%
cells['SN'] = snratio
cells['V0'] = v0s
cells['holding'] = holdings
Beispiel #8
0
def plot_cell_SN_ratio_APwise(roi_type='VolPy',
                              v0_max=-35,
                              holding_min=-600,
                              frame_rate_min=300,
                              frame_rate_max=1800,
                              F0_min=50,
                              bin_num=10):
    """Plot per-action-potential imaging statistics for every ground-truth cell.

    Parameters
    ----------
    roi_type : str
        ``roi_type`` used to select rows of ``imaging_gt.GroundTruthROI``.
    v0_max : float
        Cells whose median response (x1000; labelled mV in the original
        notes) is not below this value are left out of the histogram.
    holding_min : float
        Cells whose median stimulus (x10**12; labelled pA in the original
        notes) is not above this value are left out of the histogram.
    frame_rate_min, frame_rate_max : float
        Only APs from movies with ``movie_frame_rate`` strictly between these
        limits are used.
    F0_min : float
        Only APs with ``apwave_f0`` above this value are used.
    bin_num : int
        The F0 range of each cell is split into ``bin_num + 1`` equal bins
        for the binned plots.

    Three figures are created:

    1. mean +/- SD of F0, AP peak amplitude, noise and S/N ratio of the first
       50 APs of each cell against expression time in days;
    2. per-AP scatter plots and binned mean +/- SD curves of S/N ratio, peak
       amplitude and noise against F0;
    3. a histogram of the mean S/N ratio of the first 50 APs of the cells
       passing the ``v0_max`` / ``holding_min`` filters.

    The per-cell table is printed before and after filtering.  Nothing is
    returned.

    Note: the arguments are honoured as passed.  Earlier, every parameter
    except ``v0_max``, ``holding_min``, ``F0_min`` and ``bin_num`` values
    equal to their defaults was effectively ignored because the body
    re-assigned all of them to hard-coded values (``roi_type='VolPy_raw'``,
    ``frame_rate_min=200``, ``frame_rate_max=800`` differed from the
    signature defaults); pass those values explicitly to reproduce old plots.
    """

    def _cell_label(cell_row):
        # legend label: subject, cell number and the injected virus name
        virus_id = (lab.Surgery.VirusInjection()
                    & 'subject_id = {}'.format(
                        cell_row['subject_id'])).fetch('virus_id')[0]
        if virus_id == 238:
            virus = 'Voltron 1'
        elif virus_id == 240:
            virus = 'Voltron 2'
        else:
            virus = '??'
        return 'Subject:{}'.format(
            cell_row['subject_id']) + ' Cell:{} - {}'.format(
                cell_row['cell_number'], virus)

    cmap = cm.get_cmap('jet')

    key = {'roi_type': roi_type}
    gtdata = pd.DataFrame((imaging_gt.GroundTruthROI() & key))
    cells = gtdata.groupby([
        'session', 'subject_id', 'cell_number', 'motion_correction_method',
        'roi_type'
    ]).size().reset_index(name='Freq')
    snratio = list()
    v0s = list()
    holdings = list()
    rss = list()
    f0s_all = list()
    snratios_all = list()
    peakamplitudes_all = list()
    noise_all = list()
    for _, cell in cells.iterrows():
        key_cell = dict(cell)
        del key_cell['Freq']
        snratios, f0, peakamplitudes, noises = (
            imaging.Movie() * imaging_gt.GroundTruthROI() *
            imaging_gt.ROIAPWave() * ephysanal.ActionPotentialDetails()
            & key_cell & 'ap_real = 1'
            & 'movie_frame_rate > {}'.format(frame_rate_min)
            & 'movie_frame_rate < {}'.format(frame_rate_max)
            & 'apwave_f0 > {}'.format(F0_min)).fetch('apwave_snratio',
                                                     'apwave_f0',
                                                     'apwave_peak_amplitude',
                                                     'apwave_noise')

        # the first sweep listed for the cell is used for V0 / holding / RS
        sweep = (imaging_gt.GroundTruthROI() * imaging_gt.ROIAPWave()
                 & key_cell).fetch('sweep_number')[0]
        sweep_restriction = 'sweep_number = {}'.format(sweep)
        trace = (ephys_patch.SweepResponse() * imaging_gt.GroundTruthROI()
                 & key_cell
                 & sweep_restriction).fetch('response_trace')[0]
        stimulus = (ephys_patch.SweepStimulus() * imaging_gt.GroundTruthROI()
                    & key_cell
                    & sweep_restriction).fetch('stimulus_trace')[0]
        RS = (ephysanal.SweepSeriesResistance() * imaging_gt.GroundTruthROI()
              & key_cell
              & sweep_restriction).fetch('series_resistance')[0]

        snratio.append(np.mean(snratios[:50]))
        v0s.append(np.median(trace) * 1000)
        holdings.append(np.median(stimulus) * 10**12)
        rss.append(RS)
        f0s_all.append(f0)
        snratios_all.append(snratios)
        peakamplitudes_all.append(peakamplitudes)
        noise_all.append(noises)

    #%% statistics of the first `apnum` APs against expression time
    apnum = 50
    fig = plt.figure(figsize=[10, 10])
    ax_exptime_f0 = fig.add_subplot(221)
    ax_exptime_dff = fig.add_subplot(222)
    ax_exptime_noise = fig.add_subplot(223)
    ax_exptime_snration = fig.add_subplot(224)
    for loopidx, (f0, snratio_now, noise_now, peakampl_now,
                  cell_now) in enumerate(
                      zip(f0s_all, snratios_all, noise_all, peakamplitudes_all,
                          cells.iterrows())):
        if len(f0) > 0:
            color_now = cmap(loopidx / len(cells))
            cell_now = cell_now[1]
            label_now = _cell_label(cell_now)
            # NOTE(review): taken as the gap between the first two
            # lab.Surgery start times of the subject - presumably injection
            # and the following surgery; confirm.
            expression_time = np.diff(
                (lab.Surgery() & 'subject_id = {}'.format(
                    cell_now['subject_id'])).fetch('start_time'))[0].days
            for ax_now, values in ((ax_exptime_f0, f0),
                                   (ax_exptime_dff, peakampl_now),
                                   (ax_exptime_noise, noise_now),
                                   (ax_exptime_snration, snratio_now)):
                first_aps = values[:apnum]
                ax_now.plot(expression_time,
                            np.mean(first_aps),
                            'o',
                            ms=10,
                            color=color_now,
                            label=label_now)
                ax_now.errorbar(expression_time,
                                np.mean(first_aps),
                                np.std(first_aps),
                                ecolor=color_now)
    ax_exptime_f0.set_xlabel('expression time (days)')
    ax_exptime_f0.set_ylabel('F0 (pixel intensity)')
    ax_exptime_dff.set_xlabel('expression time (days)')
    ax_exptime_dff.set_ylabel('AP peak amplitude (dF/F)')
    ax_exptime_noise.set_xlabel('expression time (days)')
    ax_exptime_noise.set_ylabel('noise (dF/F)')
    ax_exptime_snration.set_xlabel('expression time (days)')
    ax_exptime_snration.set_ylabel('S/N ratio')

    #%%  for each AP
    fig = plt.figure()
    ax_sn_f0 = fig.add_axes([0, 0, 1, 1])
    ax_sn_f0_binned = fig.add_axes([0, -1.2, 1, 1])
    ax_noise_f0 = fig.add_axes([2.6, 0, 1, 1])
    ax_noise_f0_binned = fig.add_axes([2.6, -1.2, 1, 1])
    ax_peakampl_f0 = fig.add_axes([1.3, 0, 1, 1])
    ax_peakampl_f0_binned = fig.add_axes([1.3, -1.2, 1, 1])
    for loopidx, (f0, snratio_now, noise_now, peakampl_now,
                  cell_now) in enumerate(
                      zip(f0s_all, snratios_all, noise_all, peakamplitudes_all,
                          cells.iterrows())):
        if len(f0) > 0:
            color_now = cmap(loopidx / len(cells))
            cell_now = cell_now[1]
            label_now = _cell_label(cell_now)
            ax_sn_f0.plot(f0,
                          snratio_now,
                          'o',
                          ms=1,
                          color=color_now,
                          label=label_now)
            ax_noise_f0.plot(f0,
                             noise_now,
                             'o',
                             ms=1,
                             color=color_now,
                             label=label_now)
            ax_peakampl_f0.plot(f0,
                                peakampl_now,
                                'o',
                                ms=1,
                                color=color_now,
                                label=label_now)
            # bin_num + 1 equal-width bins over the F0 range; linspace also
            # copes with all F0 values being identical (zero-width range)
            bin_edges = np.linspace(np.min(f0), np.max(f0), bin_num + 2)
            mean_f0 = list()
            sd_f0 = list()
            mean_sn = list()
            sd_sn = list()
            mean_noise = list()
            sd_noise = list()
            mean_ampl = list()
            sd_ampl = list()
            for low, high in zip(bin_edges[:-1], bin_edges[1:]):
                idx = (f0 >= low) & (f0 < high)
                # require more than 10 APs *in this bin*; len(idx) would be
                # the total number of APs of the cell regardless of the bin
                if np.sum(idx) > 10:
                    mean_f0.append(np.mean(f0[idx]))
                    sd_f0.append(np.std(f0[idx]))
                    mean_sn.append(np.mean(snratio_now[idx]))
                    sd_sn.append(np.std(snratio_now[idx]))
                    mean_noise.append(np.mean(noise_now[idx]))
                    sd_noise.append(np.std(noise_now[idx]))
                    mean_ampl.append(np.mean(peakampl_now[idx]))
                    sd_ampl.append(np.std(peakampl_now[idx]))

            ax_sn_f0_binned.errorbar(mean_f0,
                                     mean_sn,
                                     sd_sn,
                                     sd_f0,
                                     'o-',
                                     color=color_now,
                                     label=label_now)
            ax_noise_f0_binned.errorbar(mean_f0,
                                        mean_noise,
                                        sd_noise,
                                        sd_f0,
                                        'o-',
                                        color=color_now,
                                        label=label_now)
            ax_peakampl_f0_binned.errorbar(mean_f0,
                                           mean_ampl,
                                           sd_ampl,
                                           sd_f0,
                                           'o-',
                                           color=color_now,
                                           label=label_now)

    ax_sn_f0.set_xlabel('F0')
    ax_sn_f0.set_ylabel('S/N ratio')
    ax_sn_f0_binned.set_xlabel('F0')
    ax_sn_f0_binned.set_ylabel('S/N ratio')

    ax_sn_f0_binned.legend(loc='upper center',
                           bbox_to_anchor=(-.45, 1.5),
                           shadow=True,
                           ncol=1)

    ax_noise_f0.set_xlabel('F0')
    ax_noise_f0.set_ylabel('Noise (std(dF/F))')
    ax_noise_f0_binned.set_xlabel('F0')
    ax_noise_f0_binned.set_ylabel('Noise (std(dF/F))')

    ax_peakampl_f0.set_xlabel('F0')
    ax_peakampl_f0.set_ylabel('Peak amplitude (dF/F)')
    ax_peakampl_f0_binned.set_xlabel('F0')
    ax_peakampl_f0_binned.set_ylabel('Peak amplitude (dF/F)')

    #%% per-cell summary table and S/N ratio histogram
    cells['SN'] = snratio
    cells['V0'] = v0s
    cells['holding'] = holdings
    cells['RS'] = np.asarray(rss, float)
    print(cells)
    cells = cells[cells['V0'] < v0_max]
    cells = cells[cells['holding'] > holding_min]
    print(cells)
    fig = plt.figure()
    ax_hist = fig.add_axes([0, 0, 1, 1])
    ax_hist.hist(cells['SN'].values)
    ax_hist.set_xlabel('S/N ratio of first 50 spikes')
    ax_hist.set_ylabel('# of cells')
    ax_hist.set_title(roi_type.replace('_', ' '))
    ax_hist.set_xlim([0, 15])