Example 1
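All of the examples below assume an import preamble along these lines. The module paths are assumptions based on the Allen Institute ipfx package layout; helpers such as MPSweep, get_pulse_times, or download_file belong to each snippet's host project and are not shown.

# Assumed imports (paths based on the ipfx package layout, not from the source page)
import os

import numpy as np

import ipfx.data_set_features as dsf
import ipfx.error as er
import ipfx.feature_vectors as fv
import ipfx.stim_features as stf
import ipfx.stimulus_protocol_analysis as spa
import ipfx.time_series_utils as tsu
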
def preprocess_short_square_sweeps(data_set,
                                   sweep_numbers,
                                   extra_dur=0.2,
                                   spike_window=0.05):
    if len(sweep_numbers) == 0:
        raise er.FeatureError(
            "No short square sweeps available for feature extraction")

    # Keep only sweeps that are long enough and did not end early
    good_sweep_numbers, ssq_start, ssq_end = validate_sweeps(
        data_set, sweep_numbers, extra_dur=extra_dur)
    if len(good_sweep_numbers) == 0:
        raise er.FeatureError(
            "No short square sweeps were long enough or did not end early")
    ssq_sweeps = data_set.sweep_set(good_sweep_numbers)

    ssq_spx, ssq_spfx = dsf.extractors_for_sweeps(
        ssq_sweeps,
        est_window=[ssq_start, ssq_start + 0.001],
        start=ssq_start,
        end=ssq_end + spike_window,
        reject_at_stim_start_interval=0.0002,
        **dsf.detection_parameters(data_set.SHORT_SQUARE))
    ssq_an = spa.ShortSquareAnalysis(ssq_spx, ssq_spfx)
    ssq_features = ssq_an.analyze(ssq_sweeps)

    return ssq_sweeps, ssq_features
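
A minimal calling sketch for this helper, assuming a data set and ontology set up as in Example 2 (ontology.short_square_names is an assumed attribute):

# Sketch only: data_set/ontology as in Example 2; short_square_names is assumed
ssq_numbers = data_set.filtered_sweep_table(
    clamp_mode=data_set.CURRENT_CLAMP,
    stimuli=ontology.short_square_names).sweep_number.sort_values().values
ssq_sweeps, ssq_features = preprocess_short_square_sweeps(data_set, ssq_numbers)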
Example 2
def feature_vector_input():

    TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')

    nwb_file_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    nwb_file_full_path = os.path.join(TEST_DATA_PATH, nwb_file_name)

    if not os.path.exists(nwb_file_full_path):
        download_file(nwb_file_name, nwb_file_full_path)

    data_set = AibsDataSet(nwb_file=nwb_file_full_path, ontology=ontology)

    # Select current-clamp long-square sweeps by ontology stimulus name
    lsq_sweep_numbers = data_set.filtered_sweep_table(
        clamp_mode=data_set.CURRENT_CLAMP,
        stimuli=ontology.long_square_names).sweep_number.sort_values().values

    lsq_sweeps = data_set.sweep_set(lsq_sweep_numbers)
    lsq_start, lsq_dur, _, _, _ = stf.get_stim_characteristics(
        lsq_sweeps.sweeps[0].i, lsq_sweeps.sweeps[0].t)

    lsq_end = lsq_start + lsq_dur
    lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(lsq_sweeps,
                                                  start=lsq_start,
                                                  end=lsq_end,
                                                  **dsf.detection_parameters(
                                                      data_set.LONG_SQUARE))
    lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx, subthresh_min_amp=-100.)

    lsq_features = lsq_an.analyze(lsq_sweeps)

    return lsq_sweeps, lsq_features, lsq_start, lsq_end
Example 3
def preprocess_long_square_sweeps(data_set,
                                  sweep_numbers,
                                  extra_dur=0.2,
                                  subthresh_min_amp=-100.):
    if len(sweep_numbers) == 0:
        raise er.FeatureError(
            "No long square sweeps available for feature extraction")

    good_sweep_numbers, lsq_start, lsq_end = validate_sweeps(
        data_set, sweep_numbers, extra_dur=extra_dur)
    if len(good_sweep_numbers) == 0:
        raise er.FeatureError(
            "No long square sweeps were long enough or did not end early")
    lsq_sweeps = data_set.sweep_set(good_sweep_numbers)

    lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(lsq_sweeps,
                                                  start=lsq_start,
                                                  end=lsq_end,
                                                  min_peak=-25,
                                                  **dsf.detection_parameters(
                                                      data_set.LONG_SQUARE))
    lsq_an = spa.LongSquareAnalysis(lsq_spx,
                                    lsq_spfx,
                                    subthresh_min_amp=subthresh_min_amp)
    lsq_features = lsq_an.analyze(lsq_sweeps)

    return lsq_sweeps, lsq_features, lsq_start, lsq_end, lsq_spx
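
The same pattern works for the long-square variant; the sweep selection here mirrors Example 2:

# Sketch only: data_set/ontology as in Example 2
lsq_numbers = data_set.filtered_sweep_table(
    clamp_mode=data_set.CURRENT_CLAMP,
    stimuli=ontology.long_square_names).sweep_number.sort_values().values
lsq_sweeps, lsq_features, lsq_start, lsq_end, lsq_spx = preprocess_long_square_sweeps(
    data_set, lsq_numbers)
rheobase_pa = lsq_features["rheobase_i"]  # rheobase current, reported in pA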
Example 4
def feature_vector_input():

    TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')

    nwb_file_name = "Pvalb-IRES-Cre;Ai14-415796.02.01.01.nwb"
    nwb_file_full_path = os.path.join(TEST_DATA_PATH, nwb_file_name)

    if not os.path.exists(nwb_file_full_path):
        download_file(nwb_file_name, nwb_file_full_path)

    data_set = AibsDataSet(nwb_file=nwb_file_full_path, ontology=ontology)

    lsq_sweep_numbers = [4, 5, 6, 16, 17, 18, 19, 20, 21]

    lsq_sweeps = data_set.sweep_set(lsq_sweep_numbers)
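    # Trim to the recorded portion and re-zero time at the start of the
    # "experiment" epoch so stimulus detection sees aligned traces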
    lsq_sweeps.select_epoch("recording")
    lsq_sweeps.align_to_start_of_epoch("experiment")
    lsq_start, lsq_dur, _, _, _ = stf.get_stim_characteristics(lsq_sweeps.sweeps[0].i,
                                                               lsq_sweeps.sweeps[0].t)

    lsq_end = lsq_start + lsq_dur
    lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(lsq_sweeps,
                                                  start=lsq_start,
                                                  end=lsq_end,
                                                  **dsf.detection_parameters(data_set.LONG_SQUARE))
    lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx, subthresh_min_amp=-100.)

    lsq_features = lsq_an.analyze(lsq_sweeps)

    return lsq_sweeps, lsq_features, lsq_start, lsq_end
Example 5
def preprocess_ramp_sweeps(data_set, sweep_numbers):
    if len(sweep_numbers) == 0:
        raise er.FeatureError("No ramp sweeps available for feature extraction")

    ramp_sweeps = data_set.sweep_set(sweep_numbers)
    ramp_sweeps.select_epoch("recording")

    ramp_start, ramp_dur, _, _, _ = stf.get_stim_characteristics(
        ramp_sweeps.sweeps[0].i, ramp_sweeps.sweeps[0].t)
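    # No end time is passed here; detection presumably runs from ramp_start to
    # the end of each sweep (an assumption about extractors_for_sweeps defaults)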
    ramp_spx, ramp_spfx = dsf.extractors_for_sweeps(ramp_sweeps,
                                                    start=ramp_start,
                                                    **dsf.detection_parameters(data_set.RAMP))
    ramp_an = spa.RampAnalysis(ramp_spx, ramp_spfx)
    ramp_features = ramp_an.analyze(ramp_sweeps)

    return ramp_sweeps, ramp_features, ramp_an
Example 6
    @classmethod
    def create_db_entries(cls, job, session):
        errors = []
        db = job['database']
        job_id = job['job_id']

        # Load experiment from DB
        expt = db.experiment_from_timestamp(job_id, session=session)
        nwb = expt.data
        if nwb is None:
            raise Exception('No NWB data for this experiment')
        sweeps = nwb.contents

        n_cells = len(expt.cell_list)
        ipfx_fail = 0
        for cell in expt.cell_list:
            dev_id = cell.electrode.device_id
            target_v, if_curve = get_lp_sweeps(sweeps, dev_id)
            lp_sweeps = target_v + if_curve
            if len(lp_sweeps) == 0:
                errors.append('No long pulse sweeps for cell %s' % cell.ext_id)
                continue
            recs = [rec[dev_id] for rec in lp_sweeps]
            min_pulse_dur = np.inf
            sweep_list = []
            for rec in recs:
                if rec.clamp_mode != 'ic':
                    continue

                db_rec = get_db_recording(expt, rec)
                if db_rec is None or db_rec.patch_clamp_recording.qc_pass is False:
                    continue

                pulse_times = get_pulse_times(rec)
                if pulse_times is None:
                    continue
                
                # pulses may have different durations as well, so we just use the smallest duration
                start, end = pulse_times
                min_pulse_dur = min(min_pulse_dur, end-start)
                
                sweep = MPSweep(rec, -start)
                # assumption: MPSweep acts as a factory and yields None for
                # recordings it cannot convert (a plain constructor never would)
                if sweep is None:
                    continue
                sweep_list.append(sweep)
            
            if len(sweep_list) == 0:
                errors.append('No sweeps passed qc for cell %s' % cell.ext_id)
                continue

            sweep_set = SweepSet(sweep_list)    
            spx, spfx = extractors_for_sweeps(sweep_set, start=0, end=min_pulse_dur)
            lsa = LongSquareAnalysis(spx, spfx, subthresh_min_amp=-200)
            
            try:
                analysis = lsa.analyze(sweep_set)
            except Exception as exc:
                errors.append('Error running IPFX analysis for cell %s: %s' % (cell.ext_id, str(exc)))
                ipfx_fail += 1
                continue

            spike_features = lsa.mean_features_first_spike(analysis['spikes_set'])
            up_down = spike_features['upstroke_downstroke_ratio']
            rheo = analysis['rheobase_i']
            fi_slope = analysis['fi_fit_slope']
            input_r = analysis['input_resistance']
            sag = analysis['sag']
            avg_rate = np.mean(analysis['spiking_sweeps'].avg_rate)
            adapt = np.mean(analysis['spiking_sweeps'].adapt)
            
            results = {
                'upstroke_downstroke_ratio': up_down,
                'rheobase': rheo,
                'fi_slope': fi_slope,
                'input_resistance': input_r,
                'sag': sag,
                'avg_firing_rate': avg_rate,
                'adaptation_index': adapt,
            }

            # Write new record to DB
            conn = db.Intrinsic(cell_id=cell.id, **results)
            session.add(conn)

        # If every cell failed the IPFX analysis, fail the whole job
        if ipfx_fail == n_cells and n_cells > 1:
            raise Exception('All cells failed IPFX analysis')

        return errors
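
A hypothetical call site for this pipeline hook; the class name, db, and session objects are illustrative, and the job layout is inferred from the method body:

# Hypothetical invocation; IntrinsicPipelineModule, db, session, and
# experiment_timestamp are illustrative names
job = {'database': db, 'job_id': experiment_timestamp}
errors = IntrinsicPipelineModule.create_db_entries(job, session)
for err in errors:
    print(err)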
Example 7
def extract_features(data_set, ramp_sweep_numbers, ssq_sweep_numbers, lsq_sweep_numbers,
                     amp_interval=20, max_above_rheo=100):
    features = {}
    # RAMP FEATURES -----------------
    if len(ramp_sweep_numbers) > 0:
        ramp_sweeps = data_set.sweep_set(ramp_sweep_numbers)

        ramp_start, ramp_dur, _, _, _ = stf.get_stim_characteristics(
            ramp_sweeps.sweeps[0].i, ramp_sweeps.sweeps[0].t)
        ramp_spx, ramp_spfx = dsf.extractors_for_sweeps(ramp_sweeps,
                                                        start=ramp_start,
                                                        **dsf.detection_parameters(data_set.RAMP))
        ramp_an = spa.RampAnalysis(ramp_spx, ramp_spfx)
        basic_ramp_features = ramp_an.analyze(ramp_sweeps)
        first_spike_ramp_features = first_spike_ramp(ramp_an)
        features.update(first_spike_ramp_features)

    # SHORT SQUARE FEATURES -----------------
    if len(ssq_sweep_numbers) > 0:
        ssq_sweeps = data_set.sweep_set(ssq_sweep_numbers)

        ssq_start, ssq_dur, _, _, _ = stf.get_stim_characteristics(
            ssq_sweeps.sweeps[0].i, ssq_sweeps.sweeps[0].t)
        ssq_spx, ssq_spfx = dsf.extractors_for_sweeps(ssq_sweeps,
                                                      est_window=[ssq_start, ssq_start + 0.001],
                                                      **dsf.detection_parameters(data_set.SHORT_SQUARE))
        ssq_an = spa.ShortSquareAnalysis(ssq_spx, ssq_spfx)
        basic_ssq_features = ssq_an.analyze(ssq_sweeps)
        first_spike_ssq_features = first_spike_ssq(ssq_an)
        first_spike_ssq_features["short_square_current"] = basic_ssq_features["stimulus_amplitude"]
        features.update(first_spike_ssq_features)

    # LONG SQUARE SUBTHRESHOLD FEATURES -----------------
    if len(lsq_sweep_numbers) > 0:
        check_lsq_sweeps = data_set.sweep_set(lsq_sweep_numbers)
        lsq_start, lsq_dur, _, _, _ = stf.get_stim_characteristics(
            check_lsq_sweeps.sweeps[0].i, check_lsq_sweeps.sweeps[0].t)

        # Check that all sweeps are long enough and not ended early
        extra_dur = 0.2
        # A sweep counts as "ended early" if the 100 samples before the
        # stimulus end are all exactly zero
        good_lsq_sweep_numbers = [
            n for n, s in zip(lsq_sweep_numbers, check_lsq_sweeps.sweeps)
            if s.t[-1] >= lsq_start + lsq_dur + extra_dur
            and not np.all(s.v[tsu.find_time_index(s.t, lsq_start + lsq_dur) - 100:
                               tsu.find_time_index(s.t, lsq_start + lsq_dur)] == 0)
        ]
        lsq_sweeps = data_set.sweep_set(good_lsq_sweep_numbers)

        lsq_spx, lsq_spfx = dsf.extractors_for_sweeps(lsq_sweeps,
                                                      start=lsq_start,
                                                      end=lsq_start + lsq_dur,
                                                      **dsf.detection_parameters(data_set.LONG_SQUARE))
        lsq_an = spa.LongSquareAnalysis(lsq_spx, lsq_spfx, subthresh_min_amp=-100.)
        basic_lsq_features = lsq_an.analyze(lsq_sweeps)
        features.update({
            "input_resistance": basic_lsq_features["input_resistance"],
            "tau": basic_lsq_features["tau"],
            "v_baseline": basic_lsq_features["v_baseline"],
            "sag_nearest_minus_100": basic_lsq_features["sag"],
            "sag_measured_at": basic_lsq_features["vm_for_sag"],
            "rheobase_i": int(basic_lsq_features["rheobase_i"]),
            "fi_linear_fit_slope": basic_lsq_features["fi_fit_slope"],
        })

        # TODO (maybe): port sag_from_ri code over

        # Identify suprathreshold set for analysis
        sweep_table = basic_lsq_features["spiking_sweeps"]
        mask_supra = sweep_table["stim_amp"] >= basic_lsq_features["rheobase_i"]
        sweep_indexes = fv._consolidated_long_square_indexes(sweep_table.loc[mask_supra, :])
        amps = np.rint(sweep_table.loc[sweep_indexes, "stim_amp"].values - basic_lsq_features["rheobase_i"])
        spike_data = np.array(basic_lsq_features["spikes_set"])

        for amp, swp_ind in zip(amps, sweep_indexes):
            if (amp % amp_interval != 0) or (amp > max_above_rheo) or (amp < 0):
                continue
            amp_label = int(amp / amp_interval)

            first_spike_lsq_sweep_features = first_spike_lsq(spike_data[swp_ind])
            features.update({"ap_1_{:s}_{:d}_long_square".format(f, amp_label): v
                             for f, v in first_spike_lsq_sweep_features.items()})

            mean_spike_lsq_sweep_features = mean_spike_lsq(spike_data[swp_ind])
            features.update({"ap_mean_{:s}_{:d}_long_square".format(f, amp_label): v
                             for f, v in mean_spike_lsq_sweep_features.items()})

            sweep_feature_list = [
                "first_isi",
                "avg_rate",
                "isi_cv",
                "latency",
                "median_isi",
                "adapt",
            ]

            features.update({"{:s}_{:d}_long_square".format(f, amp_label): sweep_table.at[swp_ind, f]
                             for f in sweep_feature_list})
            features["stimulus_amplitude_{:d}_long_square".format(amp_label)] = int(amp + basic_lsq_features["rheobase_i"])

        rates = sweep_table.loc[sweep_indexes, "avg_rate"].values
        features.update(fi_curve_fit(amps, rates))

    return features
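
A note on Example 7's key naming: with the default amp_interval=20, a sweep 40 pA above rheobase gets amp_label 2, so its features land under keys of the form ap_1_<feature>_2_long_square and ap_mean_<feature>_2_long_square, plus avg_rate_2_long_square and stimulus_amplitude_2_long_square.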
Example 8
def get_long_square_features(recordings, cell_id=''):
    errors = []
    if len(recordings) == 0:
        errors.append('No long pulse sweeps for cell %s' % cell_id)
        return {}, errors

    min_pulse_dur = np.inf
    sweep_list = []
    for rec in recordings:
        pulse_times = get_pulse_times(rec)
        if pulse_times is None:
            continue

        # pulses may have different durations as well, so we just use the smallest duration
        start, end = pulse_times
        min_pulse_dur = min(min_pulse_dur, end - start)

        sweep = MPSweep(rec, -start)
        if sweep is not None:
            sweep_list.append(sweep)

    if len(sweep_list) == 0:
        errors.append('No long square sweeps passed qc for cell %s' % cell_id)
        return {}, errors

    sweep_set = SweepSet(sweep_list)
    spx, spfx = extractors_for_sweeps(sweep_set, start=0, end=min_pulse_dur)
    lsa = LongSquareAnalysis(spx,
                             spfx,
                             subthresh_min_amp=-200,
                             require_subthreshold=False,
                             require_suprathreshold=False)

    try:
        analysis = lsa.analyze(sweep_set)
    except FeatureError as exc:
        err = f'Error running long square analysis for cell {cell_id}: {str(exc)}'
        logger.warning(err)
        errors.append(err)
        return {}, errors

    analysis_dict = lsa.as_dict(analysis)
    output = get_complete_long_square_features(analysis_dict)

    results = {
        'rheobase': output.get('rheobase_i', np.nan) * 1e-12,  # unscale from pA
        'fi_slope': output.get('fi_fit_slope', np.nan) * 1e-12,  # unscale from pA
        'input_resistance': output.get('input_resistance', np.nan) * 1e6,  # unscale from MOhm
        'input_resistance_ss': output.get('input_resistance_ss', np.nan) * 1e6,  # unscale from MOhm
        'tau': output.get('tau', np.nan),
        'sag': output.get('sag', np.nan),
        'sag_peak_t': output.get('sag_peak_t', np.nan),
        'sag_depol': output.get('sag_depol', np.nan),
        'sag_peak_t_depol': output.get('sag_peak_t_depol', np.nan),
        'ap_upstroke_downstroke_ratio': output.get('upstroke_downstroke_ratio_hero', np.nan),
        'ap_upstroke': output.get('upstroke_hero', np.nan) * 1e-3,  # unscale from mV
        'ap_downstroke': output.get('downstroke_hero', np.nan) * 1e-3,  # unscale from mV
        'ap_width': output.get('width_hero', np.nan),
        'ap_threshold_v': output.get('threshold_v_hero', np.nan) * 1e-3,  # unscale from mV
        'ap_peak_deltav': output.get('peak_deltav_hero', np.nan) * 1e-3,  # unscale from mV
        'ap_fast_trough_deltav': output.get('fast_trough_deltav_hero', np.nan) * 1e-3,  # unscale from mV
        'firing_rate_rheo': output.get('avg_rate_rheo', np.nan),
        'latency_rheo': output.get('latency_rheo', np.nan),
        'firing_rate_40pa': output.get('avg_rate_hero', np.nan),
        'latency_40pa': output.get('latency_hero', np.nan),
        'adaptation_index': output.get('adapt_mean', np.nan),
        'isi_cv': output.get('isi_cv_mean', np.nan),
        'isi_adapt_ratio': output.get('isi_adapt_ratio', np.nan),
        'upstroke_adapt_ratio': output.get('upstroke_adapt_ratio', np.nan),
        'downstroke_adapt_ratio': output.get('downstroke_adapt_ratio', np.nan),
        'width_adapt_ratio': output.get('width_adapt_ratio', np.nan),
        'threshold_v_adapt_ratio': output.get('threshold_v_adapt_ratio', np.nan),
    }
    return results, errors
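
A minimal calling sketch for this function; recordings would come from the host pipeline (as in Example 6) and the cell id is illustrative:

# Sketch only: `recordings` is a list of current-clamp long-pulse recordings
# supplied by the host pipeline; the cell id is illustrative
results, errors = get_long_square_features(recordings, cell_id='example-cell')
if results:
    print('rheobase (A):', results['rheobase'])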