示例#1
0
def validate_sweeps(data_set, sweep_numbers, extra_dur=0.2):
    """Select sweeps with a detectable stimulus that are recorded long enough.

    The stimulus interval is taken from the last sweep in ``sweep_numbers``
    for which ``stf.get_stim_characteristics`` detects a stimulus.

    Parameters
    ----------
    data_set : object providing ``sweep_set(sweep_numbers)``
    sweep_numbers : sequence of sweep numbers to check
    extra_dur : float
        Minimum recording time required after the stimulus end, in the same
        units as each sweep's ``t`` array.

    Returns
    -------
    good_sweep_numbers : list
        Sweep numbers that have a detected stimulus, whose last time point is
        at least ``end + extra_dur``, and whose voltage over the (up to) 100
        samples before the stimulus end is not entirely zero.
    start, end : float or None
        Stimulus start and end times. ``([], None, None)`` is returned when
        no sweep has a detectable stimulus.
    """
    check_sweeps = data_set.sweep_set(sweep_numbers)
    valid_sweep_stim = []
    start = None
    dur = None
    for swp in check_sweeps.sweeps:
        # NOTE(review): the positional False is presumably ``test_pulse``
        # (other callers pass it by keyword) -- confirm against stf.
        swp_start, swp_dur, _, _, _ = stf.get_stim_characteristics(
            swp.i, swp.t, False)
        if swp_start is None:
            valid_sweep_stim.append(False)
        else:
            start = swp_start
            dur = swp_dur
            valid_sweep_stim.append(True)
    if start is None:
        # Could not find any sweeps to define stimulus interval
        return [], None, None

    end = start + dur

    # Check that all sweeps are long enough and not ended early
    good_sweep_numbers = []
    for n, s, has_stim in zip(sweep_numbers, check_sweeps.sweeps,
                              valid_sweep_stim):
        if not (s.t[-1] >= end + extra_dur and has_stim):
            continue
        end_idx = tsu.find_time_index(s.t, end)
        # Clamp at 0: a stimulus ending within the first 100 samples would
        # otherwise give a negative start index, which wraps around and
        # produces an empty (or wrong) window that rejects the sweep.
        window = s.v[max(end_idx - 100, 0):end_idx]
        if np.all(window == 0):
            # Flat zero voltage before the stimulus end: recording ended early
            continue
        good_sweep_numbers.append(n)
    return good_sweep_numbers, start, end
示例#2
0
def feature_vector_input():
    """Build long-square sweeps and their features from the Pvalb test file.

    The NWB file is looked up in the ``data`` directory next to this module
    and downloaded there with ``download_file`` if it is missing. The
    long-square sweeps are the current-clamp sweeps whose stimulus is in
    ``ontology.long_square_names``, ordered by sweep number.

    Returns
    -------
    tuple
        ``(lsq_sweeps, lsq_features, lsq_start, lsq_end)``, where the
        stimulus interval is taken from the first long-square sweep.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    file_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    file_path = os.path.join(data_dir, file_name)

    if not os.path.exists(file_path):
        download_file(file_name, file_path)

    data_set = AibsDataSet(nwb_file=file_path, ontology=ontology)

    long_square_table = data_set.filtered_sweep_table(
        clamp_mode=data_set.CURRENT_CLAMP,
        stimuli=ontology.long_square_names)
    sweep_numbers = long_square_table.sweep_number.sort_values().values

    sweeps = data_set.sweep_set(sweep_numbers)
    first_sweep = sweeps.sweeps[0]
    stim_start, stim_dur, _, _, _ = stf.get_stim_characteristics(
        first_sweep.i, first_sweep.t)
    stim_end = stim_start + stim_dur

    detection_params = dsf.detection_parameters(data_set.LONG_SQUARE)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, end=stim_end, **detection_params)
    analysis = spa.LongSquareAnalysis(spike_extractor, train_extractor,
                                      subthresh_min_amp=-100.)
    features = analysis.analyze(sweeps)

    return sweeps, features, stim_start, stim_end
示例#3
0
def plot_sag_figures(data_set, cell_features, lims_features, sweep_features,
                     image_dir, sizes, cell_image_files):
    """Plot the subthreshold sweep(s) used for the sag measurement and save it.

    A subthreshold long-square sweep is plotted when its peak deflection
    voltage equals ``lims_features["vm_for_sag"]``. Its peak deflection point
    is marked in red. The figure is titled with ``lims_features['sag']`` and
    saved as 'sag' via ``save_figure``.

    Parameters
    ----------
    data_set : passed to ``load_sweep``
    cell_features : dict; ``["long_squares"]["subthreshold_sweeps"]`` is a
        list of dicts with ``'peak_deflect'`` (voltage, sample index) and
        ``'sweep_number'``
    lims_features : dict with ``"vm_for_sag"`` and ``'sag'``
    sweep_features : unused
    image_dir, sizes, cell_image_files : passed to ``save_figure``
    """
    fig = plt.figure()
    stim_start = None
    stim_dur = None
    for d in cell_features["long_squares"]["subthreshold_sweeps"]:
        if d['peak_deflect'][0] == lims_features["vm_for_sag"]:
            v, i, t, _, _ = load_sweep(data_set, int(d['sweep_number']))
            stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)
            plt.plot(t, v, color='black')
            # peak_deflect[1] is a sample index (it is used to index ``t``
            # elsewhere); convert it to a time so the marker sits on the trace.
            peak_t = t[int(d['peak_deflect'][1])]
            plt.scatter(peak_t,
                        d['peak_deflect'][0],
                        color='red',
                        zorder=10)
    # Only zoom to the stimulus when a matching sweep with a detectable
    # stimulus was found; otherwise keep the default limits.
    if stim_start is not None:
        plt.xlim(stim_start - 0.25, stim_start + stim_dur + 0.25)
    plt.title("sag = {:.3g}".format(lims_features['sag']))
    plt.tight_layout()

    save_figure(fig,
                'sag',
                'sag',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
示例#4
0
def feature_vector_input():
    """Build features for a fixed set of long-square sweeps from the test file.

    The NWB file is looked up in the ``data`` directory next to this module
    and downloaded there with ``download_file`` if it is missing. The sweeps
    are restricted to their "recording" epoch and aligned to the start of the
    "experiment" epoch before the stimulus interval is measured.

    Returns
    -------
    tuple
        ``(lsq_sweeps, lsq_features, lsq_start, lsq_end)``, where the
        stimulus interval is taken from the first sweep.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    file_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    file_path = os.path.join(data_dir, file_name)

    if not os.path.exists(file_path):
        download_file(file_name, file_path)

    data_set = AibsDataSet(nwb_file=file_path, ontology=ontology)

    sweeps = data_set.sweep_set([4, 5, 6, 16, 17, 18, 19, 20, 21])
    sweeps.select_epoch("recording")
    sweeps.align_to_start_of_epoch("experiment")

    first_sweep = sweeps.sweeps[0]
    stim_start, stim_dur, _, _, _ = stf.get_stim_characteristics(
        first_sweep.i, first_sweep.t)
    stim_end = stim_start + stim_dur

    detection_params = dsf.detection_parameters(data_set.LONG_SQUARE)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, end=stim_end, **detection_params)
    analysis = spa.LongSquareAnalysis(spike_extractor, train_extractor,
                                      subthresh_min_amp=-100.)
    features = analysis.analyze(sweeps)

    return sweeps, features, stim_start, stim_end
示例#5
0
def plot_sweep_set_summary(data_set,
                           highlight_sweep_number,
                           sweep_numbers,
                           highlight_color='#0779BE',
                           background_color='#dddddd'):
    """Overlay a set of sweeps in a frameless figure, highlighting one sweep.

    Every sweep in ``sweep_numbers`` is drawn thinly in ``background_color``.
    ``highlight_sweep_number`` is drawn on top in ``highlight_color``. The
    x-range spans 0.05 before to 0.25 after the highlighted sweep's stimulus,
    and the y-range is ``AXIS_Y_RANGE``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    for sn in sweep_numbers:
        v, _, t, _, _ = load_sweep(data_set, sn)
        ax.plot(t, v, linewidth=0.5, color=background_color)

    v, i, t, _, _ = load_sweep(data_set, highlight_sweep_number)
    # Draw on ``ax`` explicitly instead of relying on pyplot's
    # current-axes state, consistent with the background traces.
    ax.plot(t, v, linewidth=1, color=highlight_color)
    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)

    tstart = stim_start - 0.05
    tend = stim_start + stim_dur + 0.25

    ax.set_ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    ax.set_xlim(tstart, tend)

    return fig
示例#6
0
def test_get_stimuli_characteristics(i, test_pulse, stim_characteristics):
    """Compare get_stim_characteristics with an expected 5-tuple.

    ``stim_characteristics`` holds the expected
    (start_time, duration, amplitude, start_idx, end_idx). Duration and
    amplitude are compared with ``np.isclose``; the others exactly. The time
    axis is simply the sample index.
    """
    sample_times = np.arange(len(i))
    result = st.get_stim_characteristics(i, sample_times, test_pulse=test_pulse)
    start_time, duration, amplitude, start_idx, end_idx = result
    expected = stim_characteristics

    assert start_time == expected[0]
    assert np.isclose(duration, expected[1])
    assert np.isclose(amplitude, expected[2])
    assert start_idx == expected[3]
    assert end_idx == expected[4]
 def calc_stimparams_ipfx(time, stimulus_trace, trace_name):
     """Derive stimulus parameters from a stimulus trace.

     Parameters
     ----------
     time : array of sample times
     stimulus_trace : array of stimulus values; values are scaled by 1e12
         (presumably A -> pA -- confirm against the caller)
     trace_name : name of the trace, used in the error message

     Returns
     -------
     tuple
         ``(start_time, stim_stop, stim_amp_start, stim_amp_end,
         tot_duration, hold_curr)``. ``stim_amp_start`` is the scaled
         stimulus value at the detected start sample. ``stim_amp_end`` is
         the scaled amplitude reported by get_stim_characteristics.
         ``tot_duration`` is ``stim_stop + 1.0`` capped at ``time[-1]``, and
         ``hold_curr`` is always 0.0.

     Raises
     ------
     ValueError
         If no stimulus is detected in ``stimulus_trace``.
     """
     start_time, duration, amplitude, start_idx, _ = get_stim_characteristics(
         stimulus_trace, time)
     if start_time is None:
         # Otherwise ``start_time + duration`` below fails with an opaque
         # TypeError that does not say which trace was the problem.
         raise ValueError(
             "No stimulus detected in trace {!r}".format(trace_name))
     amplitude *= 1e12
     stim_stop = start_time + duration
     stim_amp_start = 1e12 * stimulus_trace[start_idx]
     stim_amp_end = amplitude
     tot_duration = min(time[-1], stim_stop + 1.0)  # 1sec beyond stim end
     hold_curr = 0.0
     return start_time, stim_stop, stim_amp_start, stim_amp_end, tot_duration, hold_curr
示例#8
0
def test_none_get_stimuli_characteristics():
    """A trace with only a brief early pulse yields no detected stimulus.

    Every interval field is None and the amplitude is (close to) zero.
    """
    current = [0, 1, 1, 0, 0, 0, 0, 0]
    sample_times = np.arange(len(current))
    result = st.get_stim_characteristics(current, sample_times)
    start_time, duration, amplitude, start_idx, end_idx = result

    assert start_time is None
    assert duration is None
    assert np.isclose(amplitude, 0)
    for index_value in (start_idx, end_idx):
        assert index_value is None
示例#9
0
def preprocess_ramp_sweeps(data_set, sweep_numbers):
    """Run ramp spike analysis on the given sweeps.

    The sweeps are restricted to their "recording" epoch. The stimulus start
    of the first sweep is used as the spike-detection start time, with
    detection parameters for ``data_set.RAMP``.

    Returns
    -------
    tuple
        ``(ramp_sweeps, ramp_features, ramp_an)`` -- the sweep set, the
        result of ``RampAnalysis.analyze`` and the analysis object itself.

    Raises
    ------
    er.FeatureError
        If ``sweep_numbers`` is empty.
    """
    if len(sweep_numbers) == 0:
        raise er.FeatureError("No ramp sweeps available for feature extraction")

    sweeps = data_set.sweep_set(sweep_numbers)
    sweeps.select_epoch("recording")

    first_sweep = sweeps.sweeps[0]
    stim_start, _, _, _, _ = stf.get_stim_characteristics(first_sweep.i,
                                                          first_sweep.t)
    detection_params = dsf.detection_parameters(data_set.RAMP)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, **detection_params)

    analysis = spa.RampAnalysis(spike_extractor, train_extractor)
    features = analysis.analyze(sweeps)
    return sweeps, features, analysis
示例#10
0
def plot_hero_figures(data_set, cell_features, lims_features, sweep_features,
                      image_dir, sizes, cell_image_files):
    """Save the thumbnail-sweep figures and the long-square summary figure.

    Three figures are saved via ``save_figure`` in the 'thumbnail' group:

    * 'thumbnail_0': the thumbnail sweep around its stimulus, titled with the
      sweep number and stimulus amplitude;
    * 'thumbnail_1': inter-spike intervals (threshold-to-threshold) of that
      sweep, titled with ``lims_features["adaptation"]`` when present;
    * 'ephys_summary': the figure from ``plot_long_square_summary``.
    """
    thumbnail_sweep_num = lims_features["thumbnail_sweep_num"]

    fig = plt.figure()
    v, i, t, _, _ = load_sweep(data_set, int(thumbnail_sweep_num))
    plt.plot(t, v, color='black')
    stim_start, stim_dur, stim_amp, _, _ = st.get_stim_characteristics(i, t)
    plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)
    plt.ylim(-110, 50)
    spike_times = [
        spk['threshold_t']
        for spk in get_spikes(sweep_features, thumbnail_sweep_num)
    ]
    # Assumes threshold_t shares the sweep's seconds time base (the +/-0.05
    # x-limits around the stimulus imply seconds); convert to ms so the values
    # match the "ISI (ms)" axis label.
    isis_ms = 1e3 * np.diff(np.array(spike_times))
    # int(): the sweep number may be stored as a float (it is already passed
    # through int() for load_sweep), and "{:d}" rejects floats.
    plt.title("thumbnail {:d}, amp = {:.1f}".format(
        int(thumbnail_sweep_num), stim_amp))
    plt.tight_layout()

    save_figure(fig,
                'thumbnail_0',
                'thumbnail',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)

    fig = plt.figure()
    plt.plot(range(len(isis_ms)), isis_ms)
    plt.ylabel("ISI (ms)")
    if lims_features.get("adaptation", None) is not None:
        plt.title("adapt = {:.3g}".format(lims_features["adaptation"]))
    else:
        plt.title("adapt = not defined")

    plt.tight_layout()
    save_figure(fig, 'thumbnail_1', 'thumbnail', image_dir, sizes,
                cell_image_files)

    summary_fig = plot_long_square_summary(data_set, cell_features,
                                           lims_features)
    save_figure(summary_fig,
                'ephys_summary',
                'thumbnail',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
示例#11
0
def sweeps_from_nwb(nwb_data, sweep_number_list):
    """ Generate a SweepSet object from an NWB reader and list of sweep numbers

    Sweeps should be in current-clamp mode. Responses are converted from
    volts to mV (x1e3) and stimuli from amps to pA (x1e12).

    Parameters
    ----------
    nwb_data: NwbReader
    sweep_number_list: list
        List of sweep numbers

    Returns
    -------
    sweeps: SweepSet
    stim_start: float or None
        Start time of stimulus (seconds), taken from the last sweep whose
        stimulus could be detected; None if no sweep has a detectable
        stimulus
    stim_end: float or None
        End time of stimulus (seconds), from the same sweep as stim_start
    """

    sweep_list = []
    start = None
    dur = None
    for sweep_number in sweep_number_list:
        sweep_data = nwb_data.get_sweep_data(sweep_number)
        sampling_rate = sweep_data["sampling_rate"]
        dt = 1.0 / sampling_rate
        t = np.arange(0, len(sweep_data["stimulus"])) * dt
        v = sweep_data["response"] * 1e3  # data from NWB now comes in Volts
        i = sweep_data["stimulus"] * 1e12  # data from NWB now comes in Amps
        sweep = Sweep(
            t=t,
            v=v,
            i=i,
            sampling_rate=sampling_rate,
            sweep_number=sweep_number,
            clamp_mode="CurrentClamp",
            epochs=None,
        )
        sweep_list.append(sweep)
        swp_start, swp_dur, _, _, _ = stf.get_stim_characteristics(i, t)
        # Only keep a *detected* stimulus so that a trailing sweep without
        # one does not discard an interval found in an earlier sweep.
        if swp_start is not None and swp_dur is not None:
            start, dur = swp_start, swp_dur

    sweeps = SweepSet(sweep_list)
    if start is None:
        return sweeps, None, None
    return sweeps, start, start + dur
示例#12
0
def plot_fi_curve_figures(data_set, cell_features, lims_features,
                          sweep_features, image_dir, sizes, cell_image_files):
    """Plot the f-I curve with its linear fit and save it as 'fi_curve'.

    Spiking long-square sweeps are plotted as stimulus amplitude vs average
    rate (black). When at least one rate is nonzero, a red line with slope
    ``cell_features["long_squares"]["fi_fit_slope"]`` is drawn from the first
    nonzero-rate point. The rheobase and thumbnail sweeps are overlaid in
    blue as (stimulus amplitude, spike count).
    """
    fig = plt.figure()
    fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"],
                       key=lambda s: s['stim_amp'])
    x = [d['stim_amp'] for d in fi_sorted]
    y = [d['avg_rate'] for d in fi_sorted]
    plt.scatter(x, y, color='black')

    # Guard: np.nonzero(...)[0][0] would raise IndexError when no rate is
    # nonzero (or there are no spiking sweeps); skip the fit line instead.
    nonzero_indexes = np.nonzero(y)[0]
    if len(nonzero_indexes) > 0:
        first_nonzero_idx = nonzero_indexes[0]
        x_fit = np.array(x[first_nonzero_idx:])
        plt.plot(x_fit,
                 cell_features["long_squares"]["fi_fit_slope"] *
                 (x_fit - x[first_nonzero_idx]),
                 color='red',
                 linewidth=2)
    plt.xlabel("pA")
    plt.ylabel("spikes/sec")
    plt.title("slope = {:.3g}".format(lims_features["f_i_curve_slope"]))

    rheo_hero_sweeps = [
        int(lims_features["rheobase_sweep_num"]),
        int(lims_features["thumbnail_sweep_num"])
    ]
    rheo_hero_x = []
    for s in rheo_hero_sweeps:
        _, i, t, _, _ = load_sweep(data_set, s)
        _, _, stim_amp, _, _ = st.get_stim_characteristics(i, t)
        rheo_hero_x.append(stim_amp)
    rheo_hero_y = [
        len(get_spikes(sweep_features, s)) for s in rheo_hero_sweeps
    ]
    plt.scatter(rheo_hero_x, rheo_hero_y, zorder=20, color="blue")
    plt.tight_layout()

    save_figure(fig,
                'fi_curve',
                'fi_curve',
                image_dir,
                sizes,
                cell_image_files,
                scalew=2)
示例#13
0
def plot_instantaneous_threshold_thumbnail(data_set,
                                           sweep_numbers,
                                           cell_features,
                                           lims_features,
                                           sweep_features,
                                           color='red'):
    """Plot the lowest-numbered spiking sweep, zoomed to its stimulus onset.

    The figure is frameless with hidden axes. The x-range spans 0.002 before
    the stimulus start to 0.005 after the stimulus end, and the y-range is
    ``AXIS_Y_RANGE``. If no sweep has spikes, the highest-numbered sweep is
    plotted instead.

    Parameters
    ----------
    data_set : passed to ``load_sweep``
    sweep_numbers : non-empty sequence of sweep numbers
    cell_features, lims_features : unused
    sweep_features : passed to ``get_spikes``
    color : line color of the trace

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``sweep_numbers`` is empty.
    """
    ordered = sorted(sweep_numbers)
    if not ordered:
        raise ValueError("sweep_numbers must not be empty")

    # First (i.e. lowest-numbered) sweep with at least one spike; fall back
    # to the highest-numbered sweep when none spikes.
    spiking = [sn for sn in ordered if len(get_spikes(sweep_features, sn)) > 0]
    plot_sweep_number = spiking[0] if spiking else ordered[-1]

    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    v, i, t, _, _ = load_sweep(data_set, plot_sweep_number)
    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)

    tstart = stim_start - 0.002
    tend = stim_start + stim_dur + 0.005

    ax.plot(t, v, linewidth=1, color=color)

    ax.set_ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    ax.set_xlim(tstart, tend)

    return fig
示例#14
0
def extract_features(data_set, ramp_sweep_numbers, ssq_sweep_numbers, lsq_sweep_numbers,
                     amp_interval=20, max_above_rheo=100):
    """Extract a flat dict of ramp, short-square and long-square features.

    Each feature group is skipped when its sweep-number list is empty.

    Parameters
    ----------
    data_set : provides ``sweep_set`` and the ``RAMP``, ``SHORT_SQUARE`` and
        ``LONG_SQUARE`` stimulus-type attributes
    ramp_sweep_numbers, ssq_sweep_numbers, lsq_sweep_numbers : sequences of
        sweep numbers for each stimulus type
    amp_interval : int
        Only long-square sweeps whose (rounded) amplitude above rheobase is
        a multiple of this value are reported. The multiple index becomes
        the ``_<k>_long_square`` suffix of the feature names.
    max_above_rheo : int
        Largest amplitude above rheobase (same units as ``stim_amp``) that
        is reported.

    Returns
    -------
    dict
        Feature name -> value.
    """
    features = {}

    # RAMP FEATURES -----------------
    if len(ramp_sweep_numbers) > 0:
        ramp_sweeps = data_set.sweep_set(ramp_sweep_numbers)

        ramp_start, ramp_dur, _, _, _ = stf.get_stim_characteristics(
            ramp_sweeps.sweeps[0].i, ramp_sweeps.sweeps[0].t)
        ramp_spx, ramp_spfx = dsf.extractors_for_sweeps(
            ramp_sweeps,
            start=ramp_start,
            **dsf.detection_parameters(data_set.RAMP))
        ramp_an = spa.RampAnalysis(ramp_spx, ramp_spfx)
        # Return value unused; the call is kept because first_spike_ramp
        # presumably reads state that analyze() stores on ramp_an -- confirm
        # before removing.
        ramp_an.analyze(ramp_sweeps)
        features.update(first_spike_ramp(ramp_an))

    # SHORT SQUARE FEATURES -----------------
    if len(ssq_sweep_numbers) > 0:
        ssq_sweeps = data_set.sweep_set(ssq_sweep_numbers)

        ssq_start, ssq_dur, _, _, _ = stf.get_stim_characteristics(
            ssq_sweeps.sweeps[0].i, ssq_sweeps.sweeps[0].t)
        ssq_spx, ssq_spfx = dsf.extractors_for_sweeps(
            ssq_sweeps,
            est_window=[ssq_start, ssq_start + 0.001],
            **dsf.detection_parameters(data_set.SHORT_SQUARE))
        ssq_an = spa.ShortSquareAnalysis(ssq_spx, ssq_spfx)
        basic_ssq_features = ssq_an.analyze(ssq_sweeps)
        first_spike_ssq_features = first_spike_ssq(ssq_an)
        first_spike_ssq_features["short_square_current"] = basic_ssq_features["stimulus_amplitude"]
        features.update(first_spike_ssq_features)

    # LONG SQUARE SUBTHRESHOLD FEATURES -----------------
    if len(lsq_sweep_numbers) > 0:
        check_lsq_sweeps = data_set.sweep_set(lsq_sweep_numbers)
        lsq_start, lsq_dur, _, _, _ = stf.get_stim_characteristics(
            check_lsq_sweeps.sweeps[0].i, check_lsq_sweeps.sweeps[0].t)
        lsq_end = lsq_start + lsq_dur

        # Check that all sweeps are long enough and not ended early
        extra_dur = 0.2
        good_lsq_sweep_numbers = []
        for n, s in zip(lsq_sweep_numbers, check_lsq_sweeps.sweeps):
            if not s.t[-1] >= lsq_end + extra_dur:
                continue
            end_idx = tsu.find_time_index(s.t, lsq_end)
            # Clamp at 0 so a stimulus ending within the first 100 samples
            # does not produce a negative (wrapping) slice start.
            if np.all(s.v[max(end_idx - 100, 0):end_idx] == 0):
                continue  # flat zero before stimulus end: ended early
            good_lsq_sweep_numbers.append(n)
        lsq_sweeps = data_set.sweep_set(good_lsq_sweep_numbers)

        lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(
            lsq_sweeps,
            start=lsq_start,
            end=lsq_end,
            **dsf.detection_parameters(data_set.LONG_SQUARE))
        lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx, subthresh_min_amp=-100.)
        basic_lsq_features = lsq_an.analyze(lsq_sweeps)
        rheobase_i = basic_lsq_features["rheobase_i"]
        features.update({
            "input_resistance": basic_lsq_features["input_resistance"],
            "tau": basic_lsq_features["tau"],
            "v_baseline": basic_lsq_features["v_baseline"],
            "sag_nearest_minus_100": basic_lsq_features["sag"],
            "sag_measured_at": basic_lsq_features["vm_for_sag"],
            "rheobase_i": int(rheobase_i),
            "fi_linear_fit_slope": basic_lsq_features["fi_fit_slope"],
        })

        # TODO (maybe): port sag_from_ri code over

        # Identify suprathreshold set for analysis
        sweep_table = basic_lsq_features["spiking_sweeps"]
        mask_supra = sweep_table["stim_amp"] >= rheobase_i
        sweep_indexes = fv._consolidated_long_square_indexes(sweep_table.loc[mask_supra, :])
        amps = np.rint(sweep_table.loc[sweep_indexes, "stim_amp"].values - rheobase_i)
        spike_data = np.array(basic_lsq_features["spikes_set"])

        # Per-sweep train features copied from sweep_table for each
        # reported amplitude (loop-invariant, so built once).
        sweep_feature_list = [
            "first_isi",
            "avg_rate",
            "isi_cv",
            "latency",
            "median_isi",
            "adapt",
        ]

        for amp, swp_ind in zip(amps, sweep_indexes):
            if (amp % amp_interval != 0) or (amp > max_above_rheo) or (amp < 0):
                continue
            amp_label = int(amp / amp_interval)

            first_spike_lsq_sweep_features = first_spike_lsq(spike_data[swp_ind])
            features.update({"ap_1_{:s}_{:d}_long_square".format(f, amp_label): v
                             for f, v in first_spike_lsq_sweep_features.items()})

            mean_spike_lsq_sweep_features = mean_spike_lsq(spike_data[swp_ind])
            features.update({"ap_mean_{:s}_{:d}_long_square".format(f, amp_label): v
                             for f, v in mean_spike_lsq_sweep_features.items()})

            features.update({"{:s}_{:d}_long_square".format(f, amp_label): sweep_table.at[swp_ind, f]
                             for f in sweep_feature_list})
            features["stimulus_amplitude_{:d}_long_square".format(amp_label)] = int(amp + rheobase_i)

        rates = sweep_table.loc[sweep_indexes, "avg_rate"].values
        features.update(fi_curve_fit(amps, rates))

    return features
示例#15
0
def plot_single_ap_values(data_set, sweep_numbers, lims_features,
                          sweep_features, cell_features, type_name):
    """Plot single-action-potential features for a group of sweeps.

    Creates ``3 + len(sweep_numbers)`` figures:

    * ``figs[0]``: threshold / peak / trough / fast-trough / slow-trough
      voltages of each sweep's reference spike (gray), with the
      ``lims_features`` values overlaid (blue);
    * ``figs[1]``: the corresponding times;
    * ``figs[2]``: upstroke/downstroke ratio of each sweep's first spike,
      with the ``lims_features`` value overlaid;
    * ``figs[3:]``: one trace per sweep showing the reference spike's
      features and upstroke/downstroke slope markers.

    The reference spike is the sweep's first spike, except for
    ``type_name == "long_square"``, where the first spike of the rheobase
    sweep is used.

    Parameters
    ----------
    data_set : passed to ``load_sweep``
    sweep_numbers : sequence of sweep numbers. The first one defines the
        stimulus start (and duration) used for the short-square x-limits.
    lims_features : dict with keys such as ``"threshold_v_<type_name>"``,
        ``"threshold_t_<type_name>"`` and
        ``"upstroke_downstroke_ratio_<type_name>"``
    sweep_features : passed to ``get_spikes``
    cell_features : dict; for long squares,
        ``["long_squares"]["rheobase_sweep"]["sweep_number"]`` is read
    type_name : one of ``"short_square"``, ``"long_square"``, ``"ramp"``

    Returns
    -------
    list of matplotlib figures

    Raises
    ------
    ValueError
        If ``type_name`` is not one of the supported values.
    """
    if type_name not in ("short_square", "long_square", "ramp"):
        raise ValueError("unsupported type_name: {}".format(type_name))

    figs = [plt.figure() for _ in range(3 + len(sweep_numbers))]

    _, i, t, _, _ = load_sweep(data_set, sweep_numbers[0])
    # stim_dur is only used for the short_square x-limits below.
    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)

    gen_features = [
        "threshold", "peak", "trough", "fast_trough", "slow_trough"
    ]
    voltage_features = [
        "threshold_v", "peak_v", "trough_v", "fast_trough_v", "slow_trough_v"
    ]
    time_features = [
        "threshold_t", "peak_t", "trough_t", "fast_trough_t", "slow_trough_t"
    ]

    # Stays None if no sweep has spikes; guarded in the long_square x-limits.
    times = None

    for sn in sweep_numbers:
        spikes = get_spikes(sweep_features, sn)

        if len(spikes) < 1:
            logging.warning("no spikes in sweep %d", sn)
            continue

        ref_spike = _single_ap_reference_spike(spikes, sweep_features,
                                               cell_features, type_name)
        voltages = [ref_spike[f] for f in voltage_features]
        times = [ref_spike[f] for f in time_features]

        plt.figure(figs[0].number)
        plt.scatter(range(len(voltages)), voltages, color='gray')
        plt.tight_layout()

        plt.figure(figs[1].number)
        plt.scatter(range(len(times)), times, color='gray')
        plt.tight_layout()

        plt.figure(figs[2].number)
        plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])],
                    color='gray')
        plt.tight_layout()

    plt.figure(figs[0].number)

    yvals = [
        float(lims_features[k + "_v_" + type_name]) for k in gen_features
        if lims_features[k + "_v_" + type_name] is not None
    ]
    xvals = range(len(yvals))

    plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100)
    plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str'])
    plt.title(type_name + ": voltages")

    plt.figure(figs[1].number)
    yvals = [
        float(lims_features[k + "_t_" + type_name]) for k in gen_features
        if lims_features[k + "_t_" + type_name] is not None
    ]
    xvals = range(len(yvals))
    plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100)
    plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str'])
    plt.title(type_name + ": times")

    plt.figure(figs[2].number)
    if lims_features["upstroke_downstroke_ratio_" + type_name] is not None:
        plt.scatter(
            [0],
            [float(lims_features["upstroke_downstroke_ratio_" + type_name])],
            color='blue',
            marker='_',
            s=40,
            zorder=100)
    plt.xticks([])
    plt.title(type_name + ": up/down")

    for index, sn in enumerate(sweep_numbers):
        plt.figure(figs[3 + index].number)

        v, i, t, r, dt = load_sweep(data_set, sn)
        hz = 1. / dt
        expt_start_idx, _ = ep.get_experiment_epoch(i, hz)
        # Shift time so that 0 is the start of the experiment epoch.
        t = t - expt_start_idx * dt
        stim_start_shifted = stim_start - expt_start_idx * dt
        plt.plot(t, v, color='black')
        plt.title(str(sn))

        spikes = get_spikes(sweep_features, sn)
        nspikes = len(spikes)

        delta_v = 5.0

        if nspikes:
            ref_spike = _single_ap_reference_spike(spikes, sweep_features,
                                                   cell_features, type_name)
            voltages = [ref_spike[f] for f in voltage_features]
            times = [ref_spike[f] for f in time_features]

            plt.scatter(times, voltages, color='red', zorder=20)

            _plot_single_ap_stroke(spikes[0], 'upstroke', delta_v)
            if 'downstroke_t' in spikes[0]:
                _plot_single_ap_stroke(spikes[0], 'downstroke', delta_v)

        if type_name == "ramp":
            if nspikes:
                plt.xlim(spikes[0]["threshold_t"] - 0.002,
                         spikes[0]["fast_trough_t"] + 0.01)
        elif type_name == "short_square":
            plt.xlim(stim_start_shifted - 0.002,
                     stim_start_shifted + stim_dur + 0.01)
        elif type_name == "long_square" and times is not None:
            # times[-2] is the fast-trough time of the reference spike.
            plt.xlim(times[0] - 0.002, times[-2] + 0.002)

        plt.tight_layout()

    return figs


def _single_ap_reference_spike(spikes, sweep_features, cell_features,
                               type_name):
    """Return the spike whose features are plotted for a sweep.

    This is the sweep's first spike, or for long squares the first spike of
    the rheobase sweep.
    """
    if type_name != "long_square":
        return spikes[0]
    rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["sweep_number"]
    return get_spikes(sweep_features, rheo_sn)[0]


def _plot_single_ap_stroke(spike, stroke, delta_v):
    """Draw a red segment through a spike's up/downstroke point.

    The segment spans +/- ``delta_v`` in voltage. Its time extent is
    ``1e-3 * delta_v / spike[stroke]``, i.e. it follows the slope stored
    under ``stroke``.
    """
    half_dt = 1e-3 * (delta_v / spike[stroke])
    stroke_t = spike[stroke + '_t']
    stroke_v = spike[stroke + '_v']
    plt.plot([stroke_t - half_dt, stroke_t + half_dt],
             [stroke_v - delta_v, stroke_v + delta_v],
             color='red')
示例#16
0
def plot_subthreshold_long_square_figures(data_set, cell_features,
                                          lims_features, sweep_features,
                                          image_dir, sizes, cell_image_files):
    """Save subthreshold long-square figures: VI curve, tau curve, tau fits.

    Figures saved via ``save_figure`` in the 'subthreshold_long_squares'
    group:

    * 'VI_curve': peak deflection vs stimulus amplitude, with ``vrest`` (blue)
      and the line ``amp * 1e-3 * ri + vrest`` (red);
    * 'tau_curve': per-sweep tau vs amplitude with the cell tau (red);
    * 'tau_<n>': one per membrane-property sweep, with the trace, the peak
      deflection point and the fitted exponential. All of these share one
      y-range.
    """
    long_squares = cell_features["long_squares"]
    sub_sweeps = long_squares["subthreshold_sweeps"]
    tau_sweeps = long_squares["subthreshold_membrane_property_sweeps"]

    # 0a - Plot VI curve and linear fit, along with vrest
    sub_amps = np.array([s['stim_amp'] for s in sub_sweeps])
    sub_peaks = np.array([s['peak_deflect'][0] for s in sub_sweeps])
    tau_amps = np.array([s['stim_amp'] for s in tau_sweeps])

    fig = plt.figure()
    plt.scatter(sub_amps, sub_peaks, color='black')
    plt.plot([sub_amps.min(), sub_amps.max()],
             [lims_features["vrest"], lims_features["vrest"]],
             color="blue",
             linewidth=2)
    plt.plot(tau_amps,
             tau_amps * 1e-3 * lims_features["ri"] + lims_features["vrest"],
             color="red",
             linewidth=2)
    plt.xlabel("pA")
    plt.ylabel("mV")
    plt.title("ri = {:.1f}, vrest = {:.1f}".format(lims_features["ri"],
                                                   lims_features["vrest"]))
    plt.tight_layout()

    save_figure(fig, 'VI_curve', 'subthreshold_long_squares', image_dir, sizes,
                cell_image_files)

    # 0b - Plot tau curve and average
    fig = plt.figure()
    taus = np.array([s['tau'] for s in tau_sweeps])
    plt.scatter(tau_amps, taus, color='black')
    plt.plot([tau_amps.min(), tau_amps.max()],
             [long_squares["tau"], long_squares["tau"]],
             color="red",
             linewidth=2)
    plt.xlabel("pA")
    ylim = plt.ylim()
    plt.ylim(0, ylim[1])
    plt.ylabel("tau (s)")
    plt.tight_layout()

    save_figure(fig, 'tau_curve', 'subthreshold_long_squares', image_dir,
                sizes, cell_image_files)

    subthresh_dict = {s['sweep_number']: s for s in tau_sweeps}

    # 0c - Plot the subthreshold squares
    tau_sweep_numbers = [s['sweep_number'] for s in tau_sweeps]
    tau_figs = [plt.figure() for _ in tau_sweep_numbers]

    min_y = max_y = None
    for index, sweep_number in enumerate(tau_sweep_numbers):
        v, i, t, _, _ = load_sweep(data_set, sweep_number)

        plt.figure(tau_figs[index].number)

        plt.plot(t, v, color="black")

        # Track the union of the autoscaled y-ranges (read right after the
        # trace is drawn) so all tau figures share one y-range.
        y_lo, y_hi = plt.ylim()
        if min_y is None or min_y > y_lo:
            min_y = y_lo
        if max_y is None or max_y < y_hi:
            max_y = y_hi

        stim_start, stim_dur, _, start_idx, _ = st.get_stim_characteristics(
            i, t)
        plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)
        # peak_deflect is (voltage, sample index of the peak)
        peak_v, peak_idx = (subthresh_dict[sweep_number]['peak_deflect'][0],
                            subthresh_dict[sweep_number]['peak_deflect'][1])
        peak_t = t[peak_idx]
        plt.scatter([peak_t], [peak_v], color='red', zorder=10)
        popt = subf.fit_membrane_time_constant(t, v, stim_start, peak_t)
        plt.title(str(sweep_number))
        plt.plot(t[start_idx:peak_idx],
                 exp_curve(t[start_idx:peak_idx] - t[start_idx], *popt),
                 color='blue')
        plt.xlabel("s")

    for tau_fig in tau_figs:
        plt.figure(tau_fig.number)
        plt.ylim(min_y, max_y)
        plt.tight_layout()

    for index, tau_fig in enumerate(tau_figs):
        save_figure(tau_fig, 'tau_%d' % index,
                    'subthreshold_long_squares', image_dir, sizes,
                    cell_image_files)