Beispiel #1
0
    def detect_epochs(self):
        """
        Detect epochs if they are not provided in the constructor

        Epochs already present in ``self.epochs`` are left untouched. Only the
        missing ones ("test", "sweep", "recording", "experiment", "stim") are
        computed with the corresponding ``ep`` helper and stored.

        """

        if "test" not in self.epochs:
            self.epochs["test"] = ep.get_test_epoch(self.stimulus,
                                                    self.sampling_rate)
        # Whether a test epoch was found is passed on to the experiment and
        # stim detectors below.
        test_pulse = bool(self.epochs["test"])

        # Store callables rather than results, so that a detector only runs
        # when its epoch is actually missing. Previously every detector was
        # evaluated eagerly, even for epochs supplied by the caller.
        epoch_detectors = {
            "sweep":
            lambda: ep.get_sweep_epoch(self.response),
            "recording":
            lambda: ep.get_recording_epoch(self.response),
            "experiment":
            lambda: ep.get_experiment_epoch(self._i, self.sampling_rate,
                                            test_pulse),
            "stim":
            lambda: ep.get_stim_epoch(self.stimulus, test_pulse),
        }

        for epoch_name, epoch_detector in epoch_detectors.items():
            if epoch_name not in self.epochs:
                self.epochs[epoch_name] = epoch_detector()
def experiment_plot_data(
        sweep: Sweep,
        backup_start_index: int = 5000,
        baseline_start_index: int = 5000,
        baseline_end_index: int = 9000
) -> Tuple[np.ndarray, np.ndarray, float]:
    """ Extract the data required for plotting a single sweep's experiment
    epoch.

    Parameters
    ----------
    sweep : contains data to be extracted
    backup_start_index : if the start index of this sweep's experiment epoch
        cannot be programmatically assessed, fall back to this.
    baseline_start_index : Start accumulating baseline samples from this index
        (relative to the start of the experiment epoch)
    baseline_end_index : Stop accumulating baseline samples at this index
        (exclusive; relative to the start of the experiment epoch)

    Returns
    -------
    time : timestamps (s) of voltage samples for this sweep
    voltage : in (mV). The voltage trace for this sweep's experiment epoch,
        returned as a copy with NaN samples replaced by 0.0. ``sweep.v`` is
        not modified.
    baseline_mean : the average voltage (mV) during the baseline epoch for this
        sweep, ignoring NaN samples

    """

    experiment_start_index, experiment_end_index = \
        get_experiment_epoch(sweep.i, sweep.sampling_rate) \
            or (backup_start_index, len(sweep.i))

    if experiment_start_index <= 0:
        experiment_start_index = backup_start_index

    time = sweep.t[experiment_start_index:experiment_end_index]
    # Basic slicing returns a view of sweep.v. Copy it so that zeroing NaNs
    # below does not silently overwrite the caller's data.
    voltage = np.array(
        sweep.v[experiment_start_index:experiment_end_index], copy=True)

    # Compute the baseline before zeroing NaNs, so that nanmean actually skips
    # missing samples instead of averaging them in as 0 mV.
    baseline_mean = np.nanmean(
        voltage[baseline_start_index:baseline_end_index])

    # NaNs are replaced for plotting purposes only.
    voltage[np.isnan(voltage)] = 0.0

    return time, voltage, baseline_mean
Beispiel #3
0
    def detect_epochs(self):
        """
        Fill in every epoch that was not supplied to the constructor.

        The test epoch is located first, because its presence is forwarded to
        the stim and experiment detectors. The sweep and recording epochs come
        next. The "recording" epoch is then selected, and either ``self.i``
        (current clamp) or ``self.v`` (otherwise) is used to locate the stim
        and experiment epochs. Note that "recording" remains the selected
        epoch when this method returns.

        """

        if "test" not in self.epochs:
            self.epochs["test"] = ep.get_test_epoch(self._stimulus, self.sampling_rate)
        has_test_pulse = bool(self.epochs["test"])

        if "sweep" not in self.epochs:
            self.epochs["sweep"] = ep.get_sweep_epoch(self._i)
        if "recording" not in self.epochs:
            self.epochs["recording"] = ep.get_recording_epoch(self._response)

        # Restrict the i/v properties to the valid recording before looking
        # for the stimulus.
        self.select_epoch("recording")
        if self.clamp_mode == "CurrentClamp":
            command_trace = self.i
        else:
            command_trace = self.v

        if "stim" not in self.epochs:
            self.epochs["stim"] = ep.get_stim_epoch(command_trace, has_test_pulse)
        if "experiment" not in self.epochs:
            self.epochs["experiment"] = ep.get_experiment_epoch(
                command_trace, self.sampling_rate, has_test_pulse)
Beispiel #4
0
def test_get_experiment_epoch(i, sampling_rate, expt_epoch):
    """get_experiment_epoch should report the expected experiment epoch."""
    detected = ep.get_experiment_epoch(i, sampling_rate)
    assert expt_epoch == detected
Beispiel #5
0
def plot_single_ap_values(data_set, sweep_numbers, lims_features,
                          sweep_features, cell_features, type_name):
    """Plot single action-potential features for a set of sweeps.

    Creates ``3 + len(sweep_numbers)`` figures:

    * figure 0: threshold/peak/trough/fast-trough/slow-trough voltages of the
      plotted spike of each sweep (gray), overlaid with the LIMS values (blue);
    * figure 1: the corresponding times;
    * figure 2: the upstroke/downstroke ratio;
    * one figure per sweep: the voltage trace, the feature points (red), and
      short red segments through the upstroke/downstroke points.

    Parameters
    ----------
    data_set : passed to ``load_sweep`` to fetch each sweep
    sweep_numbers : sweeps to plot. The stimulus characteristics are taken
        from the first one.
    lims_features : mapping with keys such as ``"threshold_v_<type_name>"``,
        ``"threshold_t_<type_name>"`` and
        ``"upstroke_downstroke_ratio_<type_name>"``. ``None`` values are
        not plotted.
    sweep_features : passed to ``get_spikes`` to look up a sweep's spikes
    cell_features : for "long_square", used to find the rheobase sweep
    type_name : one of "short_square", "long_square" or "ramp"

    Returns
    -------
    list of the created matplotlib figures

    Raises
    ------
    ValueError
        if ``type_name`` is not a supported stimulus type (previously this
        failed later with an UnboundLocalError on ``stim_start``)
    """
    if type_name not in ("short_square", "long_square", "ramp"):
        raise ValueError(
            "unsupported type_name %r; expected 'short_square', "
            "'long_square' or 'ramp'" % (type_name,))

    figs = [plt.figure() for _ in range(3 + len(sweep_numbers))]

    # Stimulus onset (and, for short squares, duration) come from the first
    # sweep and are reused to window the per-sweep trace plots below.
    _, i, t, _, _ = load_sweep(data_set, sweep_numbers[0])
    stim_start, stim_dur, _, _, _ = st.get_stim_characteristics(i, t)

    gen_features = [
        "threshold", "peak", "trough", "fast_trough", "slow_trough"
    ]
    tick_labels = ['thr', 'pk', 'tr', 'ftr', 'str']
    voltage_features = [f + "_v" for f in gen_features]
    time_features = [f + "_t" for f in gen_features]

    # Initialised so that the long_square x-limits below cannot hit a
    # NameError when no sweep has any spikes.
    times = None

    for sn in sweep_numbers:
        spikes = get_spikes(sweep_features, sn)

        if len(spikes) == 0:
            logging.warning("no spikes in sweep %d", sn)
            continue

        voltages, times = _ap_feature_points(
            spikes, sweep_features, cell_features, type_name,
            voltage_features, time_features)

        plt.figure(figs[0].number)
        plt.scatter(range(len(voltages)), voltages, color='gray')
        plt.tight_layout()

        plt.figure(figs[1].number)
        plt.scatter(range(len(times)), times, color='gray')
        plt.tight_layout()

        plt.figure(figs[2].number)
        plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])],
                    color='gray')
        plt.tight_layout()

    plt.figure(figs[0].number)
    _plot_lims_feature_values(lims_features, gen_features, tick_labels,
                              "_v_" + type_name)
    plt.title(type_name + ": voltages")

    plt.figure(figs[1].number)
    _plot_lims_feature_values(lims_features, gen_features, tick_labels,
                              "_t_" + type_name)
    plt.title(type_name + ": times")

    plt.figure(figs[2].number)
    ratio = lims_features["upstroke_downstroke_ratio_" + type_name]
    if ratio is not None:
        plt.scatter([0], [float(ratio)],
                    color='blue', marker='_', s=40, zorder=100)
    plt.xticks([])
    plt.title(type_name + ": up/down")

    for index, sn in enumerate(sweep_numbers):
        plt.figure(figs[3 + index].number)

        v, i, t, _, dt = load_sweep(data_set, sn)
        expt_start_idx, _ = ep.get_experiment_epoch(i, 1. / dt)
        # Shift time so that t = 0 is the start of the experiment epoch.
        offset = expt_start_idx * dt
        t = t - offset
        stim_start_shifted = stim_start - offset
        plt.plot(t, v, color='black')
        plt.title(str(sn))

        spikes = get_spikes(sweep_features, sn)
        nspikes = len(spikes)

        if nspikes:
            voltages, times = _ap_feature_points(
                spikes, sweep_features, cell_features, type_name,
                voltage_features, time_features)
            plt.scatter(times, voltages, color='red', zorder=20)

            _plot_slope_segment(spikes[0], 'upstroke')
            if 'downstroke_t' in spikes[0]:
                _plot_slope_segment(spikes[0], 'downstroke')

        if type_name == "ramp":
            if nspikes:
                plt.xlim(spikes[0]["threshold_t"] - 0.002,
                         spikes[0]["fast_trough_t"] + 0.01)
        elif type_name == "short_square":
            plt.xlim(stim_start_shifted - 0.002,
                     stim_start_shifted + stim_dur + 0.01)
        elif type_name == "long_square" and times is not None:
            plt.xlim(times[0] - 0.002, times[-2] + 0.002)

        plt.tight_layout()

    return figs


def _ap_feature_points(spikes, sweep_features, cell_features, type_name,
                       voltage_features, time_features):
    """Return ``(voltages, times)`` of the spike whose features are plotted.

    For "long_square" this is the first spike of the cell's rheobase sweep;
    otherwise it is ``spikes[0]``.
    """
    if type_name == "long_square":
        rheo_sn = cell_features["long_squares"]["rheobase_sweep"][
            "sweep_number"]
        spike = get_spikes(sweep_features, rheo_sn)[0]
    else:
        spike = spikes[0]
    return ([spike[f] for f in voltage_features],
            [spike[f] for f in time_features])


def _plot_lims_feature_values(lims_features, gen_features, tick_labels,
                              suffix):
    """Overlay LIMS reference values (blue dashes) on the current figure.

    Each feature keeps its index in ``gen_features`` as its x position, so it
    lines up with the per-sweep gray points. Features whose LIMS value is
    ``None`` are skipped. The tick labels therefore always match the ticks.
    """
    xvals, yvals = [], []
    for x, feature in enumerate(gen_features):
        value = lims_features[feature + suffix]
        if value is not None:
            xvals.append(x)
            yvals.append(float(value))
    plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100)
    plt.xticks(range(len(gen_features)), tick_labels)


def _plot_slope_segment(spike, name, delta_v=5.0):
    """Draw a red segment through ``spike[name + '_t'], spike[name + '_v']``.

    The segment spans +/- ``delta_v`` vertically. Its horizontal half-width
    is ``1e-3 * delta_v / spike[name]``, which presumably converts a slope
    given in mV/ms into seconds (TODO confirm the units of ``spike[name]``).
    """
    half_width = 1e-3 * (delta_v / spike[name])
    plt.plot([spike[name + '_t'] - half_width,
              spike[name + '_t'] + half_width],
             [spike[name + '_v'] - delta_v,
              spike[name + '_v'] + delta_v],
             color='red')
Beispiel #6
0
def load_sweep(data_set, sweep_number):
    """Fetch one sweep and unpack the pieces the plotting code needs.

    Returns
    -------
    tuple
        ``(v, i, t, experiment_epoch, dt)``. ``experiment_epoch`` is whatever
        ``ep.get_experiment_epoch`` returns for ``sweep.i`` and
        ``sweep.sampling_rate``. ``dt`` is the spacing between the first two
        time samples.
    """
    sweep = data_set.sweep(sweep_number)
    time_step = sweep.t[1] - sweep.t[0]
    experiment_epoch = ep.get_experiment_epoch(sweep.i, sweep.sampling_rate)

    return sweep.v, sweep.i, sweep.t, experiment_epoch, time_step