def validate_sweeps(data_set, sweep_numbers, extra_dur=0.2):
    """Select the sweeps that contain a usable, fully recorded stimulus.

    The stimulus interval (start, end) is taken from the LAST sweep in
    ``sweep_numbers`` in which ``stf.get_stim_characteristics`` detects a
    stimulus. A sweep is kept when all of the following hold:

    * a stimulus start was detected in that sweep,
    * its recording lasts at least ``extra_dur`` beyond the stimulus end,
    * the (up to) 100 voltage samples just before the stimulus end are not
      all exactly zero (treated as a recording that ended early).

    Parameters
    ----------
    data_set : object
        Provides ``sweep_set(sweep_numbers)``.
    sweep_numbers : sequence of int
        Candidate sweep numbers.
    extra_dur : float
        Required recording time past the stimulus end, in the units of the
        sweep time arrays (presumably seconds).

    Returns
    -------
    good_sweep_numbers : list
        Sweep numbers passing all checks, in input order.
    start : float or None
        Stimulus start time, or None if no sweep had a detectable stimulus.
    end : float or None
        Stimulus end time (``start + duration``), or None as above.
    """
    check_sweeps = data_set.sweep_set(sweep_numbers)

    valid_sweep_stim = []
    start = None
    dur = None
    for swp in check_sweeps.sweeps:
        # Third positional argument is presumably ``test_pulse=False``.
        swp_start, swp_dur, _, _, _ = stf.get_stim_characteristics(
            swp.i, swp.t, False)
        if swp_start is None:
            valid_sweep_stim.append(False)
        else:
            start = swp_start
            dur = swp_dur
            valid_sweep_stim.append(True)

    if start is None:
        # Could not find any sweeps to define stimulus interval
        return [], None, None

    end = start + dur

    def _ended_early(swp):
        """True if the samples right before the stimulus end are all zero."""
        end_idx = tsu.find_time_index(swp.t, end)
        # Clamp at 0 so a short window does not become a negative
        # (wrap-around) slice start.
        window = swp.v[max(end_idx - 100, 0):end_idx]
        return np.all(window == 0)

    # Check that all sweeps are long enough and not ended early
    good_sweep_numbers = [
        n
        for n, swp, has_stim in zip(sweep_numbers, check_sweeps.sweeps, valid_sweep_stim)
        if swp.t[-1] >= end + extra_dur and has_stim and not _ended_early(swp)
    ]
    return good_sweep_numbers, start, end
def feature_vector_input():
    """Load every long-square sweep of a reference NWB file and analyze it.

    The file ``Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb`` is looked up in the
    ``data`` directory next to this module and downloaded there if missing.
    Long-square sweeps are selected from the current-clamp sweep table and
    sorted by sweep number.

    Returns
    -------
    tuple
        ``(lsq_sweeps, lsq_features, lsq_start, lsq_end)``: the SweepSet,
        the LongSquareAnalysis feature dict, and the stimulus start/end
        times detected on the first sweep.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    nwb_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    nwb_path = os.path.join(data_dir, nwb_name)
    if not os.path.exists(nwb_path):
        download_file(nwb_name, nwb_path)

    data_set = AibsDataSet(nwb_file=nwb_path, ontology=ontology)

    sweep_table = data_set.filtered_sweep_table(
        clamp_mode=data_set.CURRENT_CLAMP,
        stimuli=ontology.long_square_names)
    sweep_numbers = sweep_table.sweep_number.sort_values().values
    sweeps = data_set.sweep_set(sweep_numbers)

    reference = sweeps.sweeps[0]
    stim_start, stim_dur, _, _, _ = stf.get_stim_characteristics(
        reference.i, reference.t)
    stim_end = stim_start + stim_dur

    detection = dsf.detection_parameters(data_set.LONG_SQUARE)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, end=stim_end, **detection)
    analysis = spa.LongSquareAnalysis(
        spike_extractor, train_extractor, subthresh_min_amp=-100.)
    features = analysis.analyze(sweeps)
    return sweeps, features, stim_start, stim_end
def plot_sag_figures(data_set, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files):
    """Plot the subthreshold sweep(s) used for the sag measurement and save it.

    Every subthreshold long-square sweep whose peak-deflection voltage equals
    ``lims_features["vm_for_sag"]`` is drawn in black, with its peak
    deflection marked in red. The x-range spans 0.25 (time units of the
    trace, presumably seconds) on either side of the stimulus. The figure is
    titled with ``lims_features['sag']`` and saved via ``save_figure`` under
    the name/category ``'sag'``.

    ``sweep_features`` is not used by this function.
    """
    fig = plt.figure()
    stim_window = None
    for d in cell_features["long_squares"]["subthreshold_sweeps"]:
        peak_v = d['peak_deflect'][0]
        if peak_v != lims_features["vm_for_sag"]:
            continue
        v, i, t, r, dt = load_sweep(data_set, int(d['sweep_number']))
        stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)
        plt.plot(t, v, color='black')
        # peak_deflect holds (voltage, sample index); convert the index to a
        # time so the marker lands on the trace's time axis.
        peak_idx = int(d['peak_deflect'][1])
        plt.scatter(t[peak_idx], peak_v, color='red', zorder=10)
        stim_window = (stim_start, stim_dur)

    # Only zoom when a matching sweep was found; otherwise the stimulus
    # window is undefined.
    if stim_window is not None:
        stim_start, stim_dur = stim_window
        plt.xlim(stim_start - 0.25, stim_start + stim_dur + 0.25)
    plt.title("sag = {:.3g}".format(lims_features['sag']))
    plt.tight_layout()
    save_figure(fig, 'sag', 'sag', image_dir, sizes, cell_image_files, scalew=2)
def feature_vector_input():
    """Load and analyze a fixed set of long-square sweeps from a reference file.

    The file ``Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb`` is looked up in the
    ``data`` directory next to this module and downloaded there if missing.
    Sweeps 4, 5, 6 and 16-21 are restricted to their "recording" epoch and
    their time axis is aligned to the start of the "experiment" epoch before
    analysis.

    Returns
    -------
    tuple
        ``(lsq_sweeps, lsq_features, lsq_start, lsq_end)``: the SweepSet,
        the LongSquareAnalysis feature dict, and the stimulus start/end
        times detected on the first sweep.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    nwb_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    nwb_path = os.path.join(data_dir, nwb_name)
    if not os.path.exists(nwb_path):
        download_file(nwb_name, nwb_path)

    data_set = AibsDataSet(nwb_file=nwb_path, ontology=ontology)

    sweeps = data_set.sweep_set([4, 5, 6, 16, 17, 18, 19, 20, 21])
    sweeps.select_epoch("recording")
    sweeps.align_to_start_of_epoch("experiment")

    reference = sweeps.sweeps[0]
    stim_start, stim_dur, _, _, _ = stf.get_stim_characteristics(
        reference.i, reference.t)
    stim_end = stim_start + stim_dur

    detection = dsf.detection_parameters(data_set.LONG_SQUARE)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, end=stim_end, **detection)
    analysis = spa.LongSquareAnalysis(
        spike_extractor, train_extractor, subthresh_min_amp=-100.)
    features = analysis.analyze(sweeps)
    return sweeps, features, stim_start, stim_end
def plot_sweep_set_summary(data_set, highlight_sweep_number, sweep_numbers,
                           highlight_color='#0779BE', background_color='#dddddd'):
    """Draw an axis-free thumbnail of a sweep set with one sweep highlighted.

    Each sweep in ``sweep_numbers`` is drawn thinly in ``background_color``;
    ``highlight_sweep_number`` is drawn on top in ``highlight_color``. The
    x-range runs from 0.05 before the highlighted sweep's stimulus onset to
    0.25 after its offset (time units of the trace, presumably seconds); the
    y-range is ``AXIS_Y_RANGE``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(frameon=False)
    # One axes filling the whole figure, with all decorations removed.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    for sn in sweep_numbers:
        v, i, t, r, dt = load_sweep(data_set, sn)
        ax.plot(t, v, linewidth=0.5, color=background_color)

    v, i, t, r, dt = load_sweep(data_set, highlight_sweep_number)
    # Plot on ``ax`` explicitly rather than relying on pyplot's current axes.
    ax.plot(t, v, linewidth=1, color=highlight_color)

    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)
    ax.set_ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    ax.set_xlim(stim_start - 0.05, stim_start + stim_dur + 0.25)
    return fig
def test_get_stimuli_characteristics(i, test_pulse, stim_characteristics):
    """Compare ``get_stim_characteristics`` output with an expected 5-tuple.

    ``i`` is the stimulus trace and ``test_pulse`` is forwarded unchanged.
    ``stim_characteristics`` holds the expected
    ``(start_time, duration, amplitude, start_idx, end_idx)``. The arguments
    are presumably supplied by pytest parametrization defined elsewhere.

    The time axis is the sample index, so expected times equal indices.
    """
    sample_times = np.arange(len(i))
    (start_time, duration, amplitude,
     start_idx, end_idx) = st.get_stim_characteristics(
        i, sample_times, test_pulse=test_pulse)
    expected = stim_characteristics

    # Start time and indices must match exactly; duration and amplitude are
    # compared with a floating-point tolerance.
    assert start_time == expected[0]
    assert np.isclose(duration, expected[1])
    assert np.isclose(amplitude, expected[2])
    assert start_idx == expected[3]
    assert end_idx == expected[4]
def calc_stimparams_ipfx(time, stimulus_trace, trace_name):
    """Derive stimulus parameters for a current trace using ``get_stim_characteristics``.

    Parameters
    ----------
    time : array-like
        Sample times; the 1.0 added to the stimulus end below suggests
        seconds (assumption, confirm with callers).
    stimulus_trace : array-like
        Injected current. Amplitudes are multiplied by 1e12, so the trace is
        presumably in amps and returned amplitudes in pA.
    trace_name : str
        Used to identify the trace in the error raised when no stimulus is
        found.

    Returns
    -------
    tuple
        ``(start_time, stim_stop, stim_amp_start, stim_amp_end,
        tot_duration, hold_curr)`` where ``stim_amp_start`` is the scaled
        trace value at the stimulus start index, ``stim_amp_end`` the scaled
        stimulus amplitude, ``tot_duration`` the stimulus end plus 1.0 capped
        at ``time[-1]``, and ``hold_curr`` always 0.0.

    Raises
    ------
    ValueError
        If no stimulus is detected in ``stimulus_trace``.
    """
    start_time, duration, amplitude, start_idx, _ = get_stim_characteristics(
        stimulus_trace, time)
    if start_time is None or duration is None or start_idx is None:
        # get_stim_characteristics reports "no stimulus" with None values;
        # the arithmetic below would otherwise fail with an opaque TypeError.
        raise ValueError(
            "No stimulus found in trace {!r}".format(trace_name))

    stim_stop = start_time + duration
    stim_amp_start = 1e12 * stimulus_trace[start_idx]
    stim_amp_end = 1e12 * amplitude
    # Extend 1.0 past the stimulus end, but never beyond the recorded trace.
    tot_duration = min(time[-1], stim_stop + 1.0)
    hold_curr = 0.0
    return start_time, stim_stop, stim_amp_start, stim_amp_end, tot_duration, hold_curr
def test_none_get_stimuli_characteristics():
    """A trace with a single short pulse yields no detected stimulus.

    With default arguments, ``get_stim_characteristics`` is expected to
    report no stimulus for this trace: start, duration and both indices are
    None and the amplitude is 0. Presumably the single pulse is consumed as
    the test pulse under the default settings.
    """
    current = [0, 1, 1, 0, 0, 0, 0, 0]
    sample_times = np.arange(len(current))
    result = st.get_stim_characteristics(current, sample_times)
    start_time, duration, amplitude, start_idx, end_idx = result

    for missing in (start_time, duration):
        assert missing is None
    assert np.isclose(amplitude, 0)
    for missing in (start_idx, end_idx):
        assert missing is None
def preprocess_ramp_sweeps(data_set, sweep_numbers):
    """Build and run a ramp analysis on the given sweeps.

    The sweeps are restricted to their "recording" epoch. Spike detection
    starts at the stimulus onset detected on the first sweep and uses the
    RAMP detection parameters.

    Returns
    -------
    tuple
        ``(ramp_sweeps, ramp_features, ramp_an)``: the SweepSet, the feature
        dict returned by ``RampAnalysis.analyze``, and the analysis object.

    Raises
    ------
    er.FeatureError
        If ``sweep_numbers`` is empty.
    """
    if not len(sweep_numbers):
        raise er.FeatureError("No ramp sweeps available for feature extraction")

    sweeps = data_set.sweep_set(sweep_numbers)
    sweeps.select_epoch("recording")

    reference = sweeps.sweeps[0]
    stim_start, _, _, _, _ = stf.get_stim_characteristics(reference.i, reference.t)

    detection = dsf.detection_parameters(data_set.RAMP)
    spike_extractor, train_extractor = dsf.extractors_for_sweeps(
        sweeps, start=stim_start, **detection)
    analysis = spa.RampAnalysis(spike_extractor, train_extractor)
    features = analysis.analyze(sweeps)
    return sweeps, features, analysis
def plot_hero_figures(data_set, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files):
    """Save the thumbnail figures for the cell's "thumbnail" (hero) sweep.

    Three figures are saved under the ``'thumbnail'`` category:

    * ``thumbnail_0`` - the hero sweep's voltage trace around its stimulus,
      titled with the sweep number and stimulus amplitude;
    * ``thumbnail_1`` - the inter-spike intervals of that sweep in ms,
      titled with ``lims_features["adaptation"]`` when present;
    * ``ephys_summary`` - the figure returned by ``plot_long_square_summary``.
    """
    hero_sweep_num = lims_features["thumbnail_sweep_num"]

    # thumbnail_0: hero sweep trace around the stimulus window.
    fig = plt.figure()
    v, i, t, r, dt = load_sweep(data_set, int(hero_sweep_num))
    plt.plot(t, v, color='black')
    stim_start, stim_dur, stim_amp, _, _ = st.get_stim_characteristics(i, t)
    plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)
    plt.ylim(-110, 50)
    spike_times = [
        spk['threshold_t'] for spk in get_spikes(sweep_features, hero_sweep_num)
    ]
    # Spike times share the trace's time base (presumably seconds, given the
    # 0.05 x-margins above); convert ISIs to ms to match the axis label.
    isis = np.diff(np.array(spike_times)) * 1e3
    # int() because the sweep number may not be stored as an int ("{:d}"
    # rejects floats/strings).
    plt.title("thumbnail {:d}, amp = {:.1f}".format(int(hero_sweep_num), stim_amp))
    plt.tight_layout()
    save_figure(fig, 'thumbnail_0', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2)

    # thumbnail_1: ISI sequence.
    fig = plt.figure()
    plt.plot(range(len(isis)), isis)
    plt.ylabel("ISI (ms)")
    adaptation = lims_features.get("adaptation", None)
    if adaptation is not None:
        plt.title("adapt = {:.3g}".format(adaptation))
    else:
        plt.title("adapt = not defined")
    plt.tight_layout()
    save_figure(fig, 'thumbnail_1', 'thumbnail', image_dir, sizes, cell_image_files)

    # ephys_summary: long-square summary figure.
    summary_fig = plot_long_square_summary(data_set, cell_features, lims_features)
    save_figure(summary_fig, 'ephys_summary', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2)
def sweeps_from_nwb(nwb_data, sweep_number_list):
    """
    Generate a SweepSet object from an NWB reader and list of sweep numbers

    Sweeps should be in current-clamp mode. Responses are converted from
    volts to mV (x1e3) and stimuli from amps to pA (x1e12). The stimulus
    interval is detected on the last sweep of ``sweep_number_list``.

    Parameters
    ----------
    nwb_data: NwbReader

    sweep_number_list: list
        List of sweep numbers

    Returns
    -------
    sweeps: SweepSet

    stim_start: float or None
        Start time of stimulus (seconds); None when the list is empty or no
        stimulus is detected in the last sweep

    stim_end: float or None
        End time of stimulus (seconds); None under the same conditions
    """
    sweep_list = []
    last_i = None
    last_t = None
    for sweep_number in sweep_number_list:
        sweep_data = nwb_data.get_sweep_data(sweep_number)
        sampling_rate = sweep_data["sampling_rate"]
        dt = 1.0 / sampling_rate
        t = np.arange(0, len(sweep_data["stimulus"])) * dt
        v = sweep_data["response"] * 1e3  # data from NWB now comes in Volts
        i = sweep_data["stimulus"] * 1e12  # data from NWB now comes in Amps
        sweep_list.append(Sweep(
            t=t,
            v=v,
            i=i,
            sampling_rate=sampling_rate,
            sweep_number=sweep_number,
            clamp_mode="CurrentClamp",
            epochs=None,
        ))
        last_i, last_t = i, t

    if last_t is None:
        return SweepSet(sweep_list), None, None

    # Only the last sweep's stimulus interval is reported, so detect it once
    # instead of recomputing and discarding it for every sweep.
    start, dur, _, _, _ = stf.get_stim_characteristics(last_i, last_t)
    if start is None or dur is None:
        return SweepSet(sweep_list), None, None
    return SweepSet(sweep_list), start, start + dur
def plot_fi_curve_figures(data_set, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files):
    """Plot and save the firing-rate vs. current (f-I) curve.

    The spiking long-square sweeps, sorted by stimulus amplitude, are drawn
    as black points. A red line starts at the first sweep with a nonzero
    average rate and uses slope ``fi_fit_slope``. The rheobase and thumbnail
    sweeps are overlaid in blue at (stimulus amplitude, spike count). The
    figure is saved as ``'fi_curve'``.
    """
    fig = plt.figure()

    spiking = sorted(cell_features["long_squares"]["spiking_sweeps"],
                     key=lambda sweep: sweep['stim_amp'])
    amps = [sweep['stim_amp'] for sweep in spiking]
    rates = [sweep['avg_rate'] for sweep in spiking]
    onset = np.nonzero(rates)[0][0]

    plt.scatter(amps, rates, color='black')
    fit_amps = amps[onset:]
    slope = cell_features["long_squares"]["fi_fit_slope"]
    plt.plot(fit_amps, slope * (np.array(fit_amps) - amps[onset]),
             color='red', linewidth=2)
    plt.xlabel("pA")
    plt.ylabel("spikes/sec")
    plt.title("slope = {:.3g}".format(lims_features["f_i_curve_slope"]))

    marked_sweeps = [
        int(lims_features["rheobase_sweep_num"]),
        int(lims_features["thumbnail_sweep_num"]),
    ]
    marked_amps = []
    for sweep_num in marked_sweeps:
        v, i, t, r, dt = load_sweep(data_set, sweep_num)
        _, _, stim_amp, _, _ = st.get_stim_characteristics(i, t)
        marked_amps.append(stim_amp)
    marked_counts = [len(get_spikes(sweep_features, sweep_num)) for sweep_num in marked_sweeps]

    plt.scatter(marked_amps, marked_counts, zorder=20, color="blue")
    plt.tight_layout()
    save_figure(fig, 'fi_curve', 'fi_curve', image_dir, sizes, cell_image_files, scalew=2)
def plot_instantaneous_threshold_thumbnail(data_set, sweep_numbers, cell_features, lims_features, sweep_features, color='red'):
    """Plot an axis-free thumbnail of the first spiking sweep around its stimulus.

    The sweep shown is the lowest-numbered entry of ``sweep_numbers`` for
    which ``get_spikes`` returns at least one spike. If no sweep has spikes,
    the highest-numbered sweep is shown instead. The x-range runs from 0.002
    before stimulus onset to 0.005 after stimulus offset (time units of the
    trace, presumably seconds); the y-range is ``AXIS_Y_RANGE``.

    ``cell_features`` and ``lims_features`` are not used by this function.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``sweep_numbers`` is empty.
    """
    ordered = sorted(sweep_numbers)
    if not ordered:
        raise ValueError("sweep_numbers must not be empty")

    # Lowest-numbered sweep that has spikes; fall back to the last sweep.
    plot_sweep_number = next(
        (sn for sn in ordered if len(get_spikes(sweep_features, sn)) > 0),
        ordered[-1])

    fig = plt.figure(frameon=False)
    # One axes filling the whole figure, with all decorations removed.
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlabel('')
    ax.set_ylabel('')

    v, i, t, r, dt = load_sweep(data_set, plot_sweep_number)
    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)
    tstart = stim_start - 0.002
    tend = stim_start + stim_dur + 0.005

    plt.plot(t, v, linewidth=1, color=color)
    plt.ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1])
    plt.xlim(tstart, tend)
    return fig
def _extract_ramp_features(data_set, sweep_numbers):
    """Return the first-spike ramp features for the given ramp sweeps."""
    ramp_sweeps = data_set.sweep_set(sweep_numbers)
    reference = ramp_sweeps.sweeps[0]
    ramp_start, _, _, _, _ = stf.get_stim_characteristics(reference.i, reference.t)
    ramp_spx, ramp_spfx = dsf.extractors_for_sweeps(
        ramp_sweeps, start=ramp_start, **dsf.detection_parameters(data_set.RAMP))
    ramp_an = spa.RampAnalysis(ramp_spx, ramp_spfx)
    # The returned dict is unused; analyze() presumably populates state on
    # ramp_an that first_spike_ramp reads (verify before removing the call).
    ramp_an.analyze(ramp_sweeps)
    return first_spike_ramp(ramp_an)


def _extract_ssq_features(data_set, sweep_numbers):
    """Return first-spike short-square features plus the stimulus amplitude."""
    ssq_sweeps = data_set.sweep_set(sweep_numbers)
    reference = ssq_sweeps.sweeps[0]
    ssq_start, _, _, _, _ = stf.get_stim_characteristics(reference.i, reference.t)
    ssq_spx, ssq_spfx = dsf.extractors_for_sweeps(
        ssq_sweeps, est_window=[ssq_start, ssq_start + 0.001],
        **dsf.detection_parameters(data_set.SHORT_SQUARE))
    ssq_an = spa.ShortSquareAnalysis(ssq_spx, ssq_spfx)
    basic_ssq_features = ssq_an.analyze(ssq_sweeps)
    ssq_features = first_spike_ssq(ssq_an)
    ssq_features["short_square_current"] = basic_ssq_features["stimulus_amplitude"]
    return ssq_features


def _extract_lsq_features(data_set, sweep_numbers, amp_interval, max_above_rheo):
    """Return long-square subthreshold, per-amplitude spike, and f-I features."""
    check_lsq_sweeps = data_set.sweep_set(sweep_numbers)
    reference = check_lsq_sweeps.sweeps[0]
    lsq_start, lsq_dur, _, _, _ = stf.get_stim_characteristics(reference.i, reference.t)
    lsq_end = lsq_start + lsq_dur

    # Check that all sweeps are long enough and not ended early (the up-to
    # 100 samples before the stimulus end must not all be zero).
    extra_dur = 0.2
    good_lsq_sweep_numbers = []
    for n, swp in zip(sweep_numbers, check_lsq_sweeps.sweeps):
        if swp.t[-1] < lsq_end + extra_dur:
            continue
        end_idx = tsu.find_time_index(swp.t, lsq_end)
        # Clamp at 0 so a short window is not a negative (wrap-around) slice.
        if np.all(swp.v[max(end_idx - 100, 0):end_idx] == 0):
            continue
        good_lsq_sweep_numbers.append(n)

    lsq_sweeps = data_set.sweep_set(good_lsq_sweep_numbers)
    lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(
        lsq_sweeps, start=lsq_start, end=lsq_end,
        **dsf.detection_parameters(data_set.LONG_SQUARE))
    lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx, subthresh_min_amp=-100.)
    basic_lsq_features = lsq_an.analyze(lsq_sweeps)
    rheobase_i = basic_lsq_features["rheobase_i"]

    features = {
        "input_resistance": basic_lsq_features["input_resistance"],
        "tau": basic_lsq_features["tau"],
        "v_baseline": basic_lsq_features["v_baseline"],
        "sag_nearest_minus_100": basic_lsq_features["sag"],
        "sag_measured_at": basic_lsq_features["vm_for_sag"],
        "rheobase_i": int(rheobase_i),
        "fi_linear_fit_slope": basic_lsq_features["fi_fit_slope"],
    }
    # TODO (maybe): port sag_from_ri code over

    # Identify suprathreshold set for analysis
    sweep_table = basic_lsq_features["spiking_sweeps"]
    mask_supra = sweep_table["stim_amp"] >= rheobase_i
    sweep_indexes = fv._consolidated_long_square_indexes(sweep_table.loc[mask_supra, :])
    amps = np.rint(sweep_table.loc[sweep_indexes, "stim_amp"].values - rheobase_i)
    # Index the per-sweep spike tables directly. np.array() on a list of
    # differently sized tables fails in modern NumPy ("inhomogeneous shape").
    # Equal-sized tables would silently become plain arrays instead.
    spike_data = list(basic_lsq_features["spikes_set"])

    sweep_feature_list = (
        "first_isi",
        "avg_rate",
        "isi_cv",
        "latency",
        "median_isi",
        "adapt",
    )
    for amp, swp_ind in zip(amps, sweep_indexes):
        # Only amplitudes at exact multiples of amp_interval above rheobase,
        # up to max_above_rheo, are reported.
        if (amp % amp_interval != 0) or (amp > max_above_rheo) or (amp < 0):
            continue
        amp_label = int(amp / amp_interval)
        sweep_spikes = spike_data[swp_ind]

        first_spike_lsq_sweep_features = first_spike_lsq(sweep_spikes)
        features.update({"ap_1_{:s}_{:d}_long_square".format(f, amp_label): v
                         for f, v in first_spike_lsq_sweep_features.items()})

        mean_spike_lsq_sweep_features = mean_spike_lsq(sweep_spikes)
        features.update({"ap_mean_{:s}_{:d}_long_square".format(f, amp_label): v
                         for f, v in mean_spike_lsq_sweep_features.items()})

        features.update({"{:s}_{:d}_long_square".format(f, amp_label): sweep_table.at[swp_ind, f]
                         for f in sweep_feature_list})
        features["stimulus_amplitude_{:d}_long_square".format(amp_label)] = int(amp + rheobase_i)

    rates = sweep_table.loc[sweep_indexes, "avg_rate"].values
    features.update(fi_curve_fit(amps, rates))
    return features


def extract_features(data_set, ramp_sweep_numbers, ssq_sweep_numbers, lsq_sweep_numbers,
                     amp_interval=20, max_above_rheo=100):
    """Extract a flat feature dict from ramp, short-square and long-square sweeps.

    Each stimulus type is processed only when its sweep-number list is
    non-empty. Results are merged into one dict in the order ramp,
    short-square, long-square.

    Parameters
    ----------
    data_set : object
        Provides ``sweep_set`` and the RAMP/SHORT_SQUARE/LONG_SQUARE constants.
    ramp_sweep_numbers, ssq_sweep_numbers, lsq_sweep_numbers : sequence of int
        Sweep numbers for each stimulus type.
    amp_interval : int
        Long-square sweeps are reported only at amplitudes that are
        multiples of this step above rheobase; the feature-name suffix is
        ``(amp - rheobase) / amp_interval``.
    max_above_rheo : int
        Largest amplitude above rheobase to report.

    Returns
    -------
    dict
    """
    features = {}

    # RAMP FEATURES -----------------
    if len(ramp_sweep_numbers) > 0:
        features.update(_extract_ramp_features(data_set, ramp_sweep_numbers))

    # SHORT SQUARE FEATURES -----------------
    if len(ssq_sweep_numbers) > 0:
        features.update(_extract_ssq_features(data_set, ssq_sweep_numbers))

    # LONG SQUARE FEATURES -----------------
    if len(lsq_sweep_numbers) > 0:
        features.update(_extract_lsq_features(
            data_set, lsq_sweep_numbers, amp_interval, max_above_rheo))

    return features
def plot_single_ap_values(data_set, sweep_numbers, lims_features, sweep_features, cell_features, type_name):
    """Plot single action-potential feature summaries for one stimulus type.

    Creates ``3 + len(sweep_numbers)`` figures:

    * figs[0] - threshold/peak/trough/fast-trough/slow-trough voltages per
      sweep (gray), plus the ``lims_features[<feature>_v_<type_name>]``
      values (blue);
    * figs[1] - the same for the corresponding times
      (``<feature>_t_<type_name>``);
    * figs[2] - each sweep's first-spike upstroke/downstroke ratio (gray),
      plus ``lims_features["upstroke_downstroke_ratio_<type_name>"]`` (blue);
    * figs[3 + k] - the trace of ``sweep_numbers[k]`` with the key spike
      points marked.

    For ``type_name == "long_square"`` the key points shown come from the
    rheobase sweep's first spike rather than the sweep's own first spike.

    Parameters
    ----------
    type_name : str
        "short_square", "long_square" or "ramp". Selects how the stimulus is
        read, the ``lims_features`` key suffix, and the x-limits of the
        per-sweep plots.

    Returns
    -------
    list of matplotlib.figure.Figure
    """
    figs = [plt.figure() for _ in range(3 + len(sweep_numbers))]

    v, i, t, r, dt = load_sweep(data_set, sweep_numbers[0])
    if type_name == "short_square" or type_name == "long_square":
        stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)
    elif type_name == "ramp":
        stim_start, _, _, _, _ = st.get_stim_characteristics(i, t)

    gen_features = ["threshold", "peak", "trough", "fast_trough", "slow_trough"]
    gen_labels = ['thr', 'pk', 'tr', 'ftr', 'str']
    voltage_features = ["threshold_v", "peak_v", "trough_v", "fast_trough_v", "slow_trough_v"]
    time_features = ["threshold_t", "peak_t", "trough_t", "fast_trough_t", "slow_trough_t"]

    def _key_points(spikes):
        """Return (voltages, times) of the spike whose key points are shown."""
        if type_name != "long_square":
            spike = spikes[0]
        else:
            rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["sweep_number"]
            spike = get_spikes(sweep_features, rheo_sn)[0]
        return [spike[f] for f in voltage_features], [spike[f] for f in time_features]

    # Gray per-sweep values on the three summary figures.
    for sn in sweep_numbers:
        spikes = get_spikes(sweep_features, sn)
        if len(spikes) < 1:
            logging.warning("no spikes in sweep %d", sn)
            continue

        voltages, times = _key_points(spikes)

        plt.figure(figs[0].number)
        plt.scatter(range(len(voltages)), voltages, color='gray')
        plt.tight_layout()

        plt.figure(figs[1].number)
        plt.scatter(range(len(times)), times, color='gray')
        plt.tight_layout()

        plt.figure(figs[2].number)
        plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])], color='gray')
        plt.tight_layout()

    # Blue reference values. Each value stays at its feature's x position so
    # labels remain aligned when some lims_features entries are None.
    for fig, key_infix, title_suffix in ((figs[0], "_v_", ": voltages"),
                                         (figs[1], "_t_", ": times")):
        plt.figure(fig.number)
        xvals = [x for x, k in enumerate(gen_features)
                 if lims_features[k + key_infix + type_name] is not None]
        yvals = [float(lims_features[gen_features[x] + key_infix + type_name]) for x in xvals]
        plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100)
        plt.xticks(range(len(gen_features)), gen_labels)
        plt.title(type_name + title_suffix)

    plt.figure(figs[2].number)
    updown = lims_features["upstroke_downstroke_ratio_" + type_name]
    if updown is not None:
        plt.scatter([0], [float(updown)], color='blue', marker='_', s=40, zorder=100)
    plt.xticks([])
    plt.title(type_name + ": up/down")

    # Half-height (in voltage units) of the upstroke/downstroke slope segments.
    delta_v = 5.0
    for index, sn in enumerate(sweep_numbers):
        plt.figure(figs[3 + index].number)

        v, i, t, r, dt = load_sweep(data_set, sn)
        hz = 1. / dt
        expt_start_idx, _ = ep.get_experiment_epoch(i, hz)
        # Shift time so that zero is the start of the experiment epoch.
        t = t - expt_start_idx * dt
        stim_start_shifted = stim_start - expt_start_idx * dt
        plt.plot(t, v, color='black')
        plt.title(str(sn))

        spikes = get_spikes(sweep_features, sn)
        nspikes = len(spikes)
        if nspikes:
            voltages, times = _key_points(spikes)
            plt.scatter(times, voltages, color='red', zorder=20)

            first = spikes[0]
            # Segments through the upstroke/downstroke points whose slope is
            # the measured rate; the 1e-3 presumably converts a per-ms rate
            # to the trace's time units (assumed seconds).
            up_half = 1e-3 * (delta_v / first['upstroke'])
            plt.plot([first['upstroke_t'] - up_half, first['upstroke_t'] + up_half],
                     [first['upstroke_v'] - delta_v, first['upstroke_v'] + delta_v],
                     color='red')

            if 'downstroke_t' in first:
                down_half = 1e-3 * (delta_v / first['downstroke'])
                plt.plot([first['downstroke_t'] - down_half, first['downstroke_t'] + down_half],
                         [first['downstroke_v'] - delta_v, first['downstroke_v'] + delta_v],
                         color='red')

        if type_name == "ramp":
            if nspikes:
                plt.xlim(spikes[0]["threshold_t"] - 0.002, spikes[0]["fast_trough_t"] + 0.01)
        elif type_name == "short_square":
            plt.xlim(stim_start_shifted - 0.002, stim_start_shifted + stim_dur + 0.01)
        elif type_name == "long_square":
            plt.xlim(times[0] - 0.002, times[-2] + 0.002)

        plt.tight_layout()

    return figs
def plot_subthreshold_long_square_figures(data_set, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files):
    """Plot and save subthreshold long-square summary figures.

    Saves into the ``'subthreshold_long_squares'`` category:

    * ``VI_curve`` - peak deflection vs. stimulus amplitude of the
      subthreshold sweeps, with a horizontal ``vrest`` line (blue) and the
      line ``amp * 1e-3 * ri + vrest`` (red) over the tau-sweep amplitudes;
    * ``tau_curve`` - fitted tau vs. amplitude of the tau sweeps, with the
      cell-level tau as a horizontal red line;
    * ``tau_<k>`` - each tau sweep's trace around its stimulus, with the peak
      deflection marked and the membrane time-constant fit drawn from the
      stimulus start to the peak. All of these share a common y-range.

    ``sweep_features`` is not used by this function.
    """
    long_squares = cell_features["long_squares"]
    sub_sweeps = long_squares["subthreshold_sweeps"]
    tau_sweeps = long_squares["subthreshold_membrane_property_sweeps"]
    tau_amps = np.array([s['stim_amp'] for s in tau_sweeps])

    # 0a - Plot VI curve and linear fit, along with vrest
    x = np.array([s['stim_amp'] for s in sub_sweeps])
    y = np.array([s['peak_deflect'][0] for s in sub_sweeps])

    fig = plt.figure()
    plt.scatter(x, y, color='black')
    plt.plot([x.min(), x.max()], [lims_features["vrest"], lims_features["vrest"]],
             color="blue", linewidth=2)
    plt.plot(tau_amps, tau_amps * 1e-3 * lims_features["ri"] + lims_features["vrest"],
             color="red", linewidth=2)
    plt.xlabel("pA")
    plt.ylabel("mV")
    plt.title("ri = {:.1f}, vrest = {:.1f}".format(lims_features["ri"], lims_features["vrest"]))
    plt.tight_layout()
    save_figure(fig, 'VI_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files)

    # 0b - Plot tau curve and average
    fig = plt.figure()
    taus = np.array([s['tau'] for s in tau_sweeps])
    plt.scatter(tau_amps, taus, color='black')
    plt.plot([tau_amps.min(), tau_amps.max()],
             [long_squares["tau"], long_squares["tau"]],
             color="red", linewidth=2)
    plt.xlabel("pA")
    ylim = plt.ylim()
    plt.ylim(0, ylim[1])
    plt.ylabel("tau (s)")
    plt.tight_layout()
    save_figure(fig, 'tau_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files)

    subthresh_dict = {s['sweep_number']: s for s in tau_sweeps}

    # 0c - Plot the subthreshold squares
    tau_sweep_numbers = [s['sweep_number'] for s in tau_sweeps]
    tau_figs = [plt.figure() for _ in tau_sweep_numbers]

    for index, sweep_num in enumerate(tau_sweep_numbers):
        v, i, t, r, dt = load_sweep(data_set, sweep_num)

        plt.figure(tau_figs[index].number)
        plt.plot(t, v, color="black")

        # Track the union of the autoscaled y-ranges so all tau figures can
        # share the same limits afterwards.
        ylims = plt.ylim()
        if index == 0:
            min_y, max_y = ylims
        else:
            min_y = min(min_y, ylims[0])
            max_y = max(max_y, ylims[1])

        stim_start, stim_dur, _, start_idx, _ = st.get_stim_characteristics(i, t)
        plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05)

        # peak_deflect holds (voltage, sample index).
        peak_v, peak_idx = subthresh_dict[sweep_num]['peak_deflect'][0], subthresh_dict[sweep_num]['peak_deflect'][1]
        peak_t = t[peak_idx]
        plt.scatter([peak_t], [peak_v], color='red', zorder=10)

        popt = subf.fit_membrane_time_constant(t, v, stim_start, peak_t)
        plt.title(str(sweep_num))
        fit_t = t[start_idx:peak_idx]
        plt.plot(fit_t, exp_curve(fit_t - t[start_idx], *popt), color='blue')
        plt.xlabel("s")

    for tau_fig in tau_figs:
        plt.figure(tau_fig.number)
        plt.ylim(min_y, max_y)
        plt.tight_layout()

    for index, tau_fig in enumerate(tau_figs):
        save_figure(tau_fig, 'tau_%d' % index, 'subthreshold_long_squares', image_dir, sizes, cell_image_files)