Esempio n. 1
0
def plot_sweep_set_summary(nwb_file,
                           highlight_sweep_number,
                           sweep_numbers,
                           highlight_color='#0779BE',
                           background_color='#dddddd'):
    """Overlay a set of sweeps on a bare figure, highlighting one of them.

    Every sweep in ``sweep_numbers`` is drawn as a thin background trace, and
    ``highlight_sweep_number`` is drawn on top in ``highlight_color``.
    The x-range is cropped around the square stimulus of the highlighted
    sweep. The y-range is fixed to ``AXIS_Y_RANGE``.

    Parameters
    ----------
    nwb_file :
        Recording handle/path passed through to ``load_experiment``.
    highlight_sweep_number : int
        Sweep drawn in the foreground; its stimulus defines the x-window.
    sweep_numbers : iterable of int
        Sweeps drawn as background traces.
    highlight_color, background_color : str
        Matplotlib colors for the foreground and background traces.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. Its single axes fills the canvas and has all
        decoration hidden.
    """
    fig = plt.figure(frameon=False)
    # A single axes covering the whole canvas with no axis/ticks/labels, so
    # the figure can be used as a thumbnail.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    for sn in sweep_numbers:
        v, _, t, _, _ = load_experiment(nwb_file, sn)
        ax.plot(t, v, linewidth=0.5, color=background_color)

    v, i, t, _, _ = load_experiment(nwb_file, highlight_sweep_number)
    # Draw on the explicit axes (like the background traces) rather than on
    # pyplot's implicit "current axes".
    ax.plot(t, v, linewidth=1, color=highlight_color)

    stim_start, stim_dur, _, _, _ = get_square_stim_characteristics(i, t)

    # Window: 0.05 before stimulus onset to 0.25 after stimulus offset, in
    # the units of t.
    tstart = stim_start - 0.05
    tend = stim_start + stim_dur + 0.25

    ax.set_ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    ax.set_xlim(tstart, tend)

    return fig
Esempio n. 2
0
def plot_sag_figures(nwb_file, cell_features, lims_features, sweep_features,
                     image_dir, sizes, cell_image_files):
    """Plot the sweep(s) used for the sag measurement and save them as 'sag'.

    A subthreshold long-square sweep is selected when its peak deflection
    voltage (``peak_deflect[0]``) equals ``lims_features["vm_for_sag"]``.
    Each selected trace is drawn in black with its peak deflection marked in
    red. The plot is titled with ``lims_features['sag']`` and written via
    ``save_figure`` (image name 'sag', category 'sag', ``scalew=2``).

    Parameters
    ----------
    nwb_file :
        Recording passed to ``load_experiment``.
    cell_features : dict
        Provides ``["long_squares"]["subthreshold_sweeps"]``.
    lims_features : dict
        Provides ``"vm_for_sag"`` and ``"sag"``.
    sweep_features :
        Unused; kept for a signature parallel to the other plot functions.
    image_dir, sizes, cell_image_files :
        Forwarded to ``save_figure``.
    """
    fig = plt.figure()
    stim_start = stim_dur = None
    for d in cell_features["long_squares"]["subthreshold_sweeps"]:
        peak_v = d['peak_deflect'][0]
        if peak_v != lims_features["vm_for_sag"]:
            continue
        v, i, t, _, _ = load_experiment(nwb_file, int(d['id']))
        stim_start, stim_dur, _, _, _ = get_square_stim_characteristics(i, t)
        plt.plot(t, v, color='black')
        # peak_deflect[1] is a sample index (it is used to slice traces
        # elsewhere), so convert it to a time before plotting against t.
        peak_t = t[int(d['peak_deflect'][1])]
        plt.scatter(peak_t, peak_v, color='red', zorder=10)

    if stim_start is None:
        # Previously this fell through to an unbound stim_start (NameError).
        logging.warning("no subthreshold sweep matches vm_for_sag")
    else:
        plt.xlim(stim_start - 0.25, stim_start + stim_dur + 0.25)
    plt.title("sag = {:.3g}".format(lims_features['sag']))
    plt.tight_layout()

    save_figure(fig,
                'sag',
                'sag',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
Esempio n. 3
0
def plot_fi_curve_figures(nwb_file, cell_features, lims_features,
                          sweep_features, image_dir, sizes, cell_image_files):
    """Plot the f-I curve of the long-square spiking sweeps and save it.

    Draws three things:

    * Black dots: ``avg_rate`` against ``stim_amp`` for the spiking sweeps,
      sorted by amplitude.
    * A red line with slope ``cell_features["long_squares"]["fi_fit_slope"]``.
      It is anchored at zero rate on the last zero-rate sweep before firing
      starts, or on the first sweep if that one already fires.
    * The rheobase and thumbnail sweeps at (stimulus amplitude, number of
      spikes).

    The figure is written via ``save_figure`` (image name 'fi_curve',
    category 'fi_curve', ``scalew=2``).

    NOTE(review): the highlighted points use the spike *count* on a
    spikes/sec axis; these only coincide for 1 s stimuli — confirm.
    """
    fig = plt.figure()
    fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"],
                       key=lambda s: s['stim_amp'])
    x = [d['stim_amp'] for d in fi_sorted]
    y = [d['avg_rate'] for d in fi_sorted]
    plt.scatter(x, y, color='black')

    firing_idx = np.nonzero(y)[0]
    if len(firing_idx) > 0:
        # Anchor at the sweep just before the first firing one. Clamp at 0 so
        # that a first sweep that already fires does not make the index -1,
        # which would wrap around to the last sweep.
        last_zero_idx = max(int(firing_idx[0]) - 1, 0)
        x_fit = np.array(x[last_zero_idx:])
        plt.plot(x_fit,
                 cell_features["long_squares"]["fi_fit_slope"] *
                 (x_fit - x[last_zero_idx]),
                 color='red')
    else:
        logging.warning("no sweep with a non-zero firing rate; "
                        "skipping f-I fit line")

    plt.xlabel("pA")
    plt.ylabel("spikes/sec")
    plt.title("slope = {:.3g}".format(lims_features["f_i_curve_slope"]))

    rheo_hero_sweeps = [
        int(lims_features["rheobase_sweep_num"]),
        int(lims_features["thumbnail_sweep_num"])
    ]
    rheo_hero_x = []
    rheo_hero_y = []
    for s in rheo_hero_sweeps:
        _, i, t, _, _ = load_experiment(nwb_file, s)
        _, _, stim_amp, _, _ = get_square_stim_characteristics(i, t)
        rheo_hero_x.append(stim_amp)
        rheo_hero_y.append(len(get_spikes(sweep_features, s)))
    plt.scatter(rheo_hero_x, rheo_hero_y, zorder=20)
    plt.tight_layout()

    save_figure(fig,
                'fi_curve',
                'fi_curve',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
Esempio n. 4
0
def plot_instantaneous_threshold_thumbnail(nwb_file,
                                           sweep_numbers,
                                           cell_features,
                                           lims_features,
                                           sweep_features,
                                           color='red'):
    """Draw the lowest-numbered spiking sweep as a bare thumbnail figure.

    Among ``sweep_numbers``, the lowest-numbered sweep for which
    ``get_spikes`` returns at least one spike is loaded. It is plotted in
    ``color`` on a figure-filling axes with no decoration.

    If no sweep has spikes, a warning is logged and the highest-numbered
    sweep is used instead. The x-range covers the square stimulus (0.002
    before onset to 0.005 after offset, in the units of t). The y-range is
    ``AXIS_Y_RANGE``.

    Parameters
    ----------
    nwb_file :
        Recording passed to ``load_experiment``.
    sweep_numbers : iterable of int
        Candidate sweeps; must be non-empty.
    cell_features, lims_features :
        Unused; kept for a signature parallel to the other plot functions.
    sweep_features :
        Passed to ``get_spikes``.
    color : str
        Trace color.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``sweep_numbers`` is empty.
    """
    ordered = sorted(sweep_numbers)
    if not ordered:
        raise ValueError("sweep_numbers must not be empty")

    # The original computed this value but then plotted the leaked loop
    # variable (always the last sweep); use it as intended.
    min_sweep_number = next(
        (sn for sn in ordered if len(get_spikes(sweep_features, sn)) > 0),
        None)
    if min_sweep_number is None:
        logging.warning("no spiking sweep among %s; using sweep %s",
                        ordered, ordered[-1])
        min_sweep_number = ordered[-1]

    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    v, i, t, _, _ = load_experiment(nwb_file, min_sweep_number)
    stim_start, stim_dur, _, _, _ = get_square_stim_characteristics(i, t)

    tstart = stim_start - 0.002
    tend = stim_start + stim_dur + 0.005

    ax.plot(t, v, linewidth=1, color=color)

    ax.set_ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    ax.set_xlim(tstart, tend)

    return fig
Esempio n. 5
0
def plot_single_ap_values(nwb_file, sweep_numbers, lims_features,
                          sweep_features, cell_features, type_name):
    """Compare single-AP features of some sweeps with the stored LIMS values.

    Creates ``3 + len(sweep_numbers)`` figures:

    * ``figs[0]``: threshold/peak/trough/fast-trough/slow-trough voltages
      (gray dots per sweep, blue dashes for ``lims_features``).
    * ``figs[1]``: the corresponding times.
    * ``figs[2]``: upstroke/downstroke ratio of each sweep's first spike.
    * ``figs[3 + k]``: the trace of ``sweep_numbers[k]`` with the feature
      points and short upstroke/downstroke tangent segments overlaid.

    For ``type_name == "long_square"``, the feature points come from the
    first spike of the rheobase sweep
    (``cell_features["long_squares"]["rheobase_sweep"]["id"]``) rather than
    from each sweep's own first spike. For the other types, a sweep without
    spikes gets a warning and no feature markers.

    Parameters
    ----------
    nwb_file :
        Recording passed to ``load_experiment``.
    sweep_numbers : sequence of int
        Sweeps to compare. Must be non-empty, because the first sweep
        defines the stimulus timing.
    lims_features : dict
        Reference values, read as ``<feature>_v_<type_name>``,
        ``<feature>_t_<type_name>`` and
        ``upstroke_downstroke_ratio_<type_name>``. ``None`` entries are
        skipped.
    sweep_features :
        Passed to ``get_spikes``.
    cell_features : dict
        Only used for the rheobase sweep id when ``type_name`` is
        ``"long_square"``.
    type_name : str
        ``"short_square"``, ``"long_square"`` or ``"ramp"``.

    Returns
    -------
    list of matplotlib.figure.Figure
    """
    figs = [plt.figure() for _ in range(3 + len(sweep_numbers))]

    # Stimulus timing comes from the first sweep and is reused for the
    # short-square x-limits below.
    _, i, t, _, _ = load_experiment(nwb_file, sweep_numbers[0])
    if type_name == "short_square" or type_name == "long_square":
        stim_start, stim_dur, _, _, _ = get_square_stim_characteristics(i, t)
    elif type_name == "ramp":
        stim_start, _ = get_ramp_stim_characteristics(i, t)

    gen_features = [
        "threshold", "peak", "trough", "fast_trough", "slow_trough"
    ]
    tick_labels = ['thr', 'pk', 'tr', 'ftr', 'str']
    voltage_features = [f + "_v" for f in gen_features]
    time_features = [f + "_t" for f in gen_features]

    # Long squares are compared through the rheobase sweep's first spike.
    # It does not depend on the sweep being processed, so look it up once.
    rheo_spike = None
    if type_name == "long_square":
        rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["id"]
        rheo_spike = get_spikes(sweep_features, rheo_sn)[0]

    def _plot_reference(fig, infix, title_suffix):
        # Overlay lims_features[<feature><infix><type_name>] as blue dashes.
        # Each value keeps its feature's x position, so it lines up with the
        # gray dots and the 5 tick labels even when some values are None.
        plt.figure(fig.number)
        positions, values = [], []
        for pos, feature in enumerate(gen_features):
            ref = lims_features[feature + infix + type_name]
            if ref is not None:
                positions.append(pos)
                values.append(float(ref))
        plt.scatter(positions, values, color='blue', marker='_', s=40,
                    zorder=100)
        plt.xticks(range(len(gen_features)), tick_labels)
        plt.title(type_name + title_suffix)

    def _plot_stroke(spike, name, delta_v):
        # Short red segment through (<name>_t, <name>_v) spanning +/- delta_v
        # in voltage, with slope spike[name]. The 1e-3 factor presumably
        # converts a per-ms slope to the time axis of t — TODO confirm units.
        half_width = 1e-3 * (delta_v / spike[name])
        plt.plot([spike[name + '_t'] - half_width,
                  spike[name + '_t'] + half_width],
                 [spike[name + '_v'] - delta_v,
                  spike[name + '_v'] + delta_v],
                 color='red')

    # Summary figures: one gray point set per sweep that has spikes.
    for sn in sweep_numbers:
        spikes = get_spikes(sweep_features, sn)
        if not spikes:
            logging.warning("no spikes in sweep %d", sn)
            continue

        ref_spike = rheo_spike if type_name == "long_square" else spikes[0]
        voltages = [ref_spike[f] for f in voltage_features]
        times = [ref_spike[f] for f in time_features]

        plt.figure(figs[0].number)
        plt.scatter(range(len(voltages)), voltages, color='gray')
        plt.tight_layout()

        plt.figure(figs[1].number)
        plt.scatter(range(len(times)), times, color='gray')
        plt.tight_layout()

        plt.figure(figs[2].number)
        plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])],
                    color='gray')
        plt.tight_layout()

    _plot_reference(figs[0], "_v_", ": voltages")
    _plot_reference(figs[1], "_t_", ": times")

    plt.figure(figs[2].number)
    ratio = lims_features["upstroke_downstroke_ratio_" + type_name]
    if ratio is not None:
        plt.scatter([0], [float(ratio)], color='blue', marker='_', s=40,
                    zorder=100)
    plt.xticks([])
    plt.title(type_name + ": up/down")

    # Per-sweep trace figures.
    delta_v = 5.0
    for fig, sn in zip(figs[3:], sweep_numbers):
        plt.figure(fig.number)

        v, _, t, _, _ = load_experiment(nwb_file, sn)
        plt.plot(t, v, color='black')
        plt.title(str(sn))

        spikes = get_spikes(sweep_features, sn)
        nspikes = len(spikes)

        if type_name == "long_square":
            ref_spike = rheo_spike
        elif nspikes:
            ref_spike = spikes[0]
        else:
            # Previously an unreachable check let spike-less ramp/short
            # square sweeps fall through to the long-square rheobase spike.
            logging.warning("no spikes in sweep %d", sn)
            ref_spike = None

        times = None
        if ref_spike is not None:
            voltages = [ref_spike[f] for f in voltage_features]
            times = [ref_spike[f] for f in time_features]
            plt.scatter(times, voltages, color='red', zorder=20)

        if nspikes:
            _plot_stroke(spikes[0], 'upstroke', delta_v)
            if 'downstroke_t' in spikes[0]:
                _plot_stroke(spikes[0], 'downstroke', delta_v)
            else:
                logging.warning("spike has no downstroke time, clipped")

        if type_name == "ramp":
            if nspikes:
                plt.xlim(spikes[0]["threshold_t"] - 0.002,
                         spikes[0]["fast_trough_t"] + 0.01)
        elif type_name == "short_square":
            plt.xlim(stim_start - 0.002, stim_start + stim_dur + 0.01)
        elif type_name == "long_square":
            # times[-2] is fast_trough_t (see time_features order).
            plt.xlim(times[0] - 0.002, times[-2] + 0.002)

        plt.tight_layout()

    return figs
Esempio n. 6
0
def plot_hero_figures(nwb_file, cell_features, lims_features, sweep_features,
                      image_dir, sizes, cell_image_files):
    """Plot and save the thumbnail ("hero") sweep figures for a cell.

    Saves four figures via ``save_figure``:

    * 'thumbnail_0': the ``thumbnail_sweep_num`` trace around its stimulus,
      titled with the sweep number and stimulus amplitude.
    * 'thumbnail_1': inter-spike intervals (from spike ``threshold_t``) in
      ms, titled with ``lims_features["adaptation"]`` when present.
    * 'thumbnail_2': the has_delay / has_burst / has_pause flags, with a
      missing or ``None`` flag shown as 0.
    * 'ephys_summary': the figure returned by ``plot_long_square_summary``.

    Unlike the previous version, ``lims_features`` is not modified.
    """
    thumb_sn = int(lims_features["thumbnail_sweep_num"])

    fig = plt.figure()
    v, i, t, _, _ = load_experiment(nwb_file, thumb_sn)
    plt.plot(t, v, color='black')
    stim_start, stim_dur, stim_amp, _, _ = get_square_stim_characteristics(
        i, t)
    plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)
    plt.ylim(-110, 50)
    spike_times = [
        spk['threshold_t'] for spk in get_spikes(sweep_features, thumb_sn)
    ]
    # threshold_t lies on the same time axis as t (seconds, judging by the
    # 0.05 window above), so scale to ms to match the "ISI (ms)" label.
    isis = np.diff(np.array(spike_times)) * 1e3
    plt.title("thumbnail {:d}, amp = {:.1f}".format(thumb_sn, stim_amp))
    plt.tight_layout()

    save_figure(fig,
                'thumbnail_0',
                'thumbnail',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)

    fig = plt.figure()
    plt.plot(range(len(isis)), isis)
    plt.ylabel("ISI (ms)")
    if lims_features.get("adaptation", None) is not None:
        plt.title("adapt = {:.3g}".format(lims_features["adaptation"]))
    else:
        plt.title("adapt = not defined")

    plt.tight_layout()
    save_figure(fig, 'thumbnail_1', 'thumbnail', image_dir, sizes,
                cell_image_files)

    # Missing/None flags are treated as False without mutating the input.
    yvals = []
    for key in ("has_delay", "has_burst", "has_pause"):
        flag = lims_features.get(key, None)
        yvals.append(float(flag) if flag is not None else float(False))
    xvals = range(len(yvals))

    fig = plt.figure()
    plt.scatter(xvals, yvals, color='red')
    plt.xticks(xvals, ['Delay', 'Burst', 'Pause'])
    plt.title("flags")
    plt.tight_layout()

    save_figure(fig, 'thumbnail_2', 'thumbnail', image_dir, sizes,
                cell_image_files)

    summary_fig = plot_long_square_summary(nwb_file, cell_features,
                                           lims_features, sweep_features)
    save_figure(summary_fig,
                'ephys_summary',
                'thumbnail',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
Esempio n. 7
0
def plot_subthreshold_long_square_figures(nwb_file, cell_features,
                                          lims_features, sweep_features,
                                          image_dir, sizes, cell_image_files):
    """Plot and save the subthreshold long-square diagnostic figures.

    Saves, via ``save_figure`` in category 'subthreshold_long_squares':

    * 'VI_curve': peak deflection vs. stimulus amplitude of the subthreshold
      sweeps. A blue line marks ``vrest``; a red line shows
      ``amp * 1e-3 * ri + vrest`` over the membrane-property sweeps.
    * 'tau_curve': per-sweep ``tau`` vs. amplitude, with the cell-level
      ``tau`` as a red line.
    * 'tau_<k>': one figure per membrane-property sweep. Each shows the
      trace, the peak deflection, and the exponential from
      ``ft.fit_membrane_time_constant``. All of these figures share one
      y-range.

    ``sweep_features`` is unused; it is kept for a parallel signature.
    """
    long_squares = cell_features["long_squares"]
    sub_sweeps = long_squares["subthreshold_sweeps"]
    tau_sweeps = long_squares["subthreshold_membrane_property_sweeps"]
    tau_amps = np.array([s['stim_amp'] for s in tau_sweeps])
    vrest = lims_features["vrest"]
    ri = lims_features["ri"]

    # 0a - Plot VI curve and linear fit, along with vrest
    sub_amps = np.array([s['stim_amp'] for s in sub_sweeps])
    sub_peaks = np.array([s['peak_deflect'][0] for s in sub_sweeps])

    fig = plt.figure()
    plt.scatter(sub_amps, sub_peaks, color='black')
    plt.plot([sub_amps.min(), sub_amps.max()], [vrest, vrest],
             color="blue",
             linewidth=2)
    # The 1e-3 factor presumably converts pA * MOhm to mV — confirm units
    # of ri.
    plt.plot(tau_amps,
             tau_amps * 1e-3 * ri + vrest,
             color="red",
             linewidth=2)
    plt.xlabel("pA")
    plt.ylabel("mV")
    plt.title("ri = {:.1f}, vrest = {:.1f}".format(ri, vrest))
    plt.tight_layout()

    save_figure(fig, 'VI_curve', 'subthreshold_long_squares', image_dir, sizes,
                cell_image_files)

    # 0b - Plot tau curve and average
    cell_tau = long_squares["tau"]
    fig = plt.figure()
    plt.scatter(tau_amps, np.array([s['tau'] for s in tau_sweeps]),
                color='black')
    plt.plot([tau_amps.min(), tau_amps.max()], [cell_tau, cell_tau],
             color="red",
             linewidth=2)
    plt.xlabel("pA")
    plt.ylim(0, plt.ylim()[1])
    plt.ylabel("tau (s)")
    plt.tight_layout()

    save_figure(fig, 'tau_curve', 'subthreshold_long_squares', image_dir,
                sizes, cell_image_files)

    # 0c - Plot the subthreshold squares
    subthresh_dict = {s['id']: s for s in tau_sweeps}
    tau_ids = [s['id'] for s in tau_sweeps]
    tau_figs = [plt.figure() for _ in tau_ids]

    # Union of the autoscaled y-ranges of the raw traces (taken before the
    # markers/fit are drawn), applied to every figure afterwards.
    y_lo, y_hi = np.inf, -np.inf
    for tau_fig, sweep_id in zip(tau_figs, tau_ids):
        v, i, t, _, dt = load_experiment(nwb_file, sweep_id)

        plt.figure(tau_fig.number)
        plt.plot(t, v, color="black")

        lo, hi = plt.ylim()
        y_lo = min(y_lo, lo)
        y_hi = max(y_hi, hi)

        stim_start, stim_dur, _, start_idx, _ = \
            get_square_stim_characteristics(i, t)
        plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)

        peak_v, peak_idx = subthresh_dict[sweep_id]['peak_deflect'][:2]
        # peak_deflect[1] is a sample index; convert it to a time.
        peak_t = peak_idx * dt
        plt.scatter([peak_t], [peak_v], color='red', zorder=10)

        popt = ft.fit_membrane_time_constant(v, t, stim_start, peak_t)
        plt.title(str(sweep_id))
        fit_t = t[start_idx:peak_idx]
        plt.plot(fit_t, exp_curve(fit_t - t[start_idx], *popt), color='blue')

    for tau_fig in tau_figs:
        plt.figure(tau_fig.number)
        plt.ylim(y_lo, y_hi)
        plt.tight_layout()

    for index, tau_fig in enumerate(tau_figs):
        save_figure(tau_fig, 'tau_%d' % index,
                    'subthreshold_long_squares', image_dir, sizes,
                    cell_image_files)
Esempio n. 8
0
def MLIN(voltage, current, res, cap, dt, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, PUBLICATION_PLOT=False):
    '''Estimate subthreshold voltage-noise statistics during a square stimulus.

    The analysis window starts ``int(0.5 / dt)`` samples after stimulus onset
    (0.5 s if ``dt`` is in seconds) and ends at stimulus offset. The voltage in
    that window is mean-subtracted and its variance measured. Its
    autocorrelation is then fitted with ``exp_decay`` to obtain a time
    constant.

    Parameters
    ----------
    voltage : numpy.ndarray
        Voltage trace with the test pulse cut out.
    current : numpy.ndarray
        Stimulus trace with the test pulse cut out (same sampling as voltage).
    res, cap : float
        ``res * cap`` is used as the initial guess for the autocorrelation
        time constant, and as the 'RC' reference curve in the plot.
    dt : float
        Sampling interval.
    MAKE_PLOT : bool
        Draw a diagnostic figure (trace, histogram, CDF, autocorrelation).
    SHOW_PLOT : bool
        Call ``plt.show(block=BLOCK)`` after drawing (only with MAKE_PLOT).
    BLOCK : bool
        Forwarded to ``plt.show``.
    PUBLICATION_PLOT : bool
        Also draw a publication-style figure (only with MAKE_PLOT).

    Returns
    -------
    var_of_section : float
        Variance of the mean-subtracted voltage window.
    sv_for_expsymm : float
        Standard deviation of the window divided by sqrt(2). Used as the
        scale of the ``expsymm_pdf`` / ``expsymm_cdf`` curves.
    tau_from_AC : float
        Decay constant of the exponential fitted to the autocorrelation.
    '''
    t = np.arange(0, len(current)) * dt
    (_, _, _, start_idx, end_idx) = get_square_stim_characteristics(current, t, no_test_pulse=True)

    distribution_start_ind = start_idx + int(.5 / dt)
    distribution_end_ind = end_idx

    v_section = voltage[distribution_start_ind:distribution_end_ind]
    if MAKE_PLOT:
        times = np.arange(0, len(voltage)) * dt
        plt.figure(figsize=(15, 11))
        plt.subplot2grid((7, 2), (0, 0), colspan=2)
        plt.plot(times[distribution_start_ind:distribution_end_ind], v_section)
        plt.title('voltage for histogram')

    # (A leftover debug print of the whole v_section array was removed here.)
    v_section = v_section - np.mean(v_section)
    var_of_section = np.var(v_section)
    sv_for_expsymm = np.std(v_section) / np.sqrt(2)
    subthreshold_long_square_voltage_distribution = stats.norm(loc=0, scale=np.sqrt(var_of_section))

    #--autocorrelation
    tau_4AC = res * cap
    AC = autocorr(v_section - np.mean(v_section))
    ACtime = np.arange(0, len(AC)) * dt

    #--fit autocorrelation with decaying exponential
    (popt, pcov) = curve_fit(exp_decay, ACtime, AC, p0=[AC[0], tau_4AC])
    tau_from_AC = popt[1]

    if MAKE_PLOT:
        plt.subplot2grid((7, 2), (1, 0), rowspan=3)
        # `normed` was removed from hist in matplotlib 3.1; `density` is the
        # equivalent keyword.
        plt.hist(v_section, bins=50, density=True, label='data')
        data_grid = np.arange(min(v_section), max(v_section), abs(min(v_section) - max(v_section)) / 100.)
        plt.plot(data_grid, subthreshold_long_square_voltage_distribution.pdf(data_grid), 'r', label='gauss with\nmeasured var')
        plt.plot(data_grid, expsymm_pdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm function')
        plt.xlabel('voltage (mV)')
        plt.title('Mean subtracted voltage hist')
        plt.legend()

        #--cumulative density function
        (h, edges) = np.histogram(v_section, bins=50)
        centers = find_bin_center(edges)

        CDFx = centers
        CDFy = np.cumsum(h) / float(len(v_section))

        plt.subplot2grid((7, 2), (4, 0), rowspan=3)
        plt.plot(CDFx, CDFy, label='data')
        plt.plot(data_grid, subthreshold_long_square_voltage_distribution.cdf(data_grid), 'r', label='gauss with\nmeasured var')
        plt.plot(data_grid, expsymm_cdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm func')
        plt.title('Normalized cumulative sum')
        plt.xlabel('v-mean(v)')
        plt.legend()

        plt.subplot2grid((7, 2), (1, 1), rowspan=3)
        plt.plot(ACtime, AC, label='data')
        plt.xlabel('shift (s)')
        plt.title('Auto correlation')
        plt.plot(ACtime, exp_decay(ACtime, AC[0], tau_4AC), label='RC')
        plt.plot(ACtime, exp_decay(ACtime, popt[0], tau_from_AC), label='fit')
        plt.legend()

        plt.tight_layout()
        if SHOW_PLOT:
            plt.show(block=BLOCK)

        if PUBLICATION_PLOT:
            # Same quantities, rescaled by 1e3 for display.
            plt.figure(figsize=(14, 7))
            plt.subplot2grid((3, 3), (0, 0), colspan=3)
            plt.xlabel('time (s)', fontsize=14)
            plt.ylabel('(mV)', fontsize=14)
            plt.plot(times[distribution_start_ind:distribution_end_ind], v_section * 1.e3)
            plt.title('Voltage for histogram', fontsize=16)

            plt.subplot2grid((3, 3), (1, 0), rowspan=2)
            plt.hist(v_section * 1.e3, bins=50, density=True, label='data')
            data_grid = np.arange(min(v_section), max(v_section), abs(min(v_section) - max(v_section)) / 100.)
            plt.plot(data_grid * 1.e3, 1.e-3 * expsymm_pdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm ')
            plt.xlabel('voltage (mV)', fontsize=14)
            plt.title('Mean subtracted voltage hist', fontsize=16)
            plt.legend(loc=1)

            #--cumulative density function
            (h, edges) = np.histogram(v_section, bins=50)
            centers = find_bin_center(edges)

            CDFx = centers
            CDFy = np.cumsum(h) / float(len(v_section))

            plt.subplot2grid((3, 3), (1, 1), rowspan=2)
            plt.plot(CDFx * 1e3, CDFy, label='data')
            plt.plot(data_grid * 1.e3, expsymm_cdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm')
            plt.title('Normalized cumulative sum', fontsize=16)
            plt.xlabel('V-mean(V) (mV)', fontsize=16)
            plt.legend(loc=2, fontsize=14)

            plt.subplot2grid((3, 3), (1, 2), rowspan=2)
            plt.plot(ACtime, AC * 1.e3, label='data')
            plt.xlabel('shift (s)', fontsize=14)
            plt.title('Auto correlation', fontsize=16)
            plt.plot(ACtime, exp_decay(ACtime, popt[0] * 1.e3, tau_from_AC), lw=3, label='fit')
            plt.legend(loc=1)
            plt.tight_layout()

    return var_of_section, sv_for_expsymm, tau_from_AC