import os
import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyabf
import scipy.signal
import scipy.stats
import seaborn as sns
import sklearn.covariance
import sklearn.decomposition

# `abf`, `rlab_signal`, and `burst` are project-local modules (channel lookup,
# signal conditioning, burst-shape extraction); they, along with the helpers
# `sr_conversion`, `burst_stats_dia`, `get_sigh_idx`, `shapes_to_stats`, and
# `viz_peaks` used below, are assumed to be importable from this package.


def plot_stim_aligned(f_list,
                      stim_chan,
                      xii_chan,
                      dia_chan,
                      pre=0.1,
                      post=0.5):

    xii_slice = []
    dia_slice = []
    for f in f_list:
        dat = pyabf.ABF(f)
        sr = dat.dataRate
        pre_samp = int(pre * sr)
        post_samp = int(post * sr)
        # Build the time axis (ms) from sample counts so its length always
        # matches the slices taken below; np.arange on floats can be off by one.
        tvec = np.arange(-pre_samp, post_samp) / sr * 1000

        stim_id = abf.get_channel_id_by_label(dat, stim_chan)
        xii_id = abf.get_channel_id_by_label(dat, xii_chan)
        dia_id = abf.get_channel_id_by_label(dat, dia_chan)

        for ii in range(dat.sweepCount):
            dat.setSweep(ii, stim_id)
            stim_on, stim_off = rlab_signal.binary_onsets(dat.sweepY, 1)
            stim_on = stim_on[0]
            stim_off = stim_off[0]
            dat.setSweep(ii, xii_id)
            xii_data = dat.sweepY
            xii_slice.append(xii_data[stim_on - pre_samp:stim_on + post_samp])

            dat.setSweep(ii, dia_id)
            dia_data = dat.sweepY
            dia_data = rlab_signal.remove_EKG(dia_data, sr)
            dia_data = rlab_signal.integrator(dia_data,
                                              sr,
                                              0.05,
                                              acausal=False)
            dia_slice.append(dia_data[stim_on - pre_samp:stim_on + post_samp])

    xii_slice = np.array(xii_slice).T
    dia_slice = np.array(dia_slice).T

    fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(3, 2))
    ax[0].plot(tvec, xii_slice, 'k', alpha=0.2, lw=1)
    ax[0].axvspan(0, (stim_off - stim_on) / sr * 1000, color='c', alpha=0.3)

    ax[1].plot(tvec, dia_slice, 'k', alpha=0.2, lw=1)
    ax[1].axvspan(0, (stim_off - stim_on) / sr * 1000, color='c', alpha=0.3)
    ax[1].set_xlabel('Time (ms)')

    ax[0].set_ylabel('XII')
    ax[1].set_ylabel('Dia')
    sns.despine()
    plt.tight_layout()
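
A minimal usage sketch; the file names and channel labels below are hypothetical and must match the labels in your ABF protocol:

# Hypothetical example call for plot_stim_aligned.
files = ['experiment_01.abf', 'experiment_02.abf']  # placeholder file names
plot_stim_aligned(files, stim_chan='Stim', xii_chan='XII', dia_chan='Dia',
                  pre=0.1, post=0.5)
plt.savefig('stim_aligned.png', dpi=300)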
Example No. 2
def get_amplitude(f_list, sig_chan):
    yy = []
    for f in f_list:
        dat = pyabf.ABF(f)
        sig_id = abf.get_channel_id_by_label(dat, sig_chan)
        dat.setSweep(0, sig_id)  # select the requested channel before reading
        yy.append(dat.sweepY)

    y_med = np.median(yy)
    # np.percentile takes a percentage in [0, 100]; 0.75 would return the
    # 0.75th percentile rather than the intended 75th.
    y_p75 = np.percentile(yy, 75)
    return (y_med, y_p75)
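
A hedged usage sketch; the file names and channel label are hypothetical:

# Robust amplitude summary across recordings: median and 75th percentile.
med, p75 = get_amplitude(['experiment_01.abf', 'experiment_02.abf'], 'Dia')
print(f'median = {med:.3f}, 75th percentile = {p75:.3f}')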
Example No. 3
def compute_sigh(f_list, stim_chan, sig_chan, tag):
    # ===================== ===================== #
    # Initialize parameters #
    # ===================== ===================== #
    pre_win = 0.5
    post_win = 2.5
    amp_thresh = 1.5  # Number of IQRs above the baseline median a burst must exceed
    prom_mult = 2.5  # Bigger number misses small bursts
    iqr_thresh = 4.  # Bigger number calls more evoked bursts eupneas
    lockout = 1.  # seconds
    p_fig = r'../figs'
    p_results = r'../results'
    used_feats = ['Amplitude', 'Next IBI',
                  'Width']  # What to plot in the feature space

    # ===================== ===================== #
    # Init data collectors
    all_proms = []
    all_widths = []
    all_ons = []
    all_offs = []
    all_next = []
    all_prev = []
    all_pks = []
    all_area = []
    all_evoked_shape = []

    base_pks = []
    base_proms = []
    base_widths = []
    base_ons = []
    base_offs = []
    base_prev = []
    base_next = []
    base_area = []
    all_base_shapes = []

    ev_sweep_ID = []
    ev_file_ID = []

    base_sweep_ID = []
    base_file_ID = []

    # ====================== ==================== #
    # collect bursts
    nstims = 0
    fig = plt.figure(figsize=(3, 7))
    subfig_path = os.path.join(p_fig, tag)
    if os.path.isdir(subfig_path):
        print(f'Already finished {tag}')
        return 0
    os.mkdir(subfig_path)
    # y_med,y_std = get_amplitude(f_list,sig_chan)

    for f in f_list:
        dat = pyabf.ABF(f)
        print(f)
        stim_id = abf.get_channel_id_by_label(dat, stim_chan)
        sig_id = abf.get_channel_id_by_label(dat, sig_chan)
        for ii in range(dat.sweepCount):
            nstims += 1

            # Get stimulus
            dat.setSweep(ii, stim_id)
            stim_vals = sr_conversion(dat.sweepY, dat.dataRate)[0]
            stim_samp = rlab_signal.binary_onsets(stim_vals, 1.)[0][0]
            dat.setSweep(ii, sig_id)

            # Convert if not sampling at 1000
            y = dat.sweepY
            sr = dat.dataRate
            y, sr = sr_conversion(y, sr)
            pre_win_samp = int(pre_win * sr)
            post_win_samp = int(post_win * sr)

            # Find peaks and props
            pks, props = scipy.signal.find_peaks(y,
                                                 distance=int(0.750 * sr),
                                                 prominence=np.std(y) *
                                                 prom_mult)
            plt.close('all')
            f2 = plt.figure()
            plt.plot(y, 'k')
            plt.plot(pks, y[pks], 'o')
            plt.savefig(os.path.join(subfig_path, f'{tag}_{nstims}_y.png'))
            plt.close('all')

            proms = props['prominences']
            _, _, onsets, offsets = scipy.signal.peak_widths(y,
                                                             pks,
                                                             rel_height=0.8)
            _, sprt, _, _ = scipy.signal.peak_widths(y, pks, rel_height=1)
            fwhm = scipy.signal.peak_widths(y, pks, rel_height=0.5)[0]
            onsets = onsets.astype('int')
            offsets = offsets.astype('int')

            # y_seg = y[(stim_samp - 2000):(stim_samp + 2000)]
            # y_seg += np.std(y_seg) * nstims

            y_plt = y.copy()
            y_plt[y_plt > 0.1] = np.nan
            # plt.plot(y_plt+0.01*nstims,'k',alpha=0.6,linewidth=1)

            # Get the index of the evoked burst
            evoked_bool = np.logical_and(pks > stim_samp, pks <
                                         (stim_samp + int(lockout * sr)))
            if not np.any(evoked_bool):
                continue
            evoked_idx = np.where(evoked_bool)[0][0]
            base_idx = np.where(np.invert(evoked_bool))[0]
            # plt.vlines(onsets[evoked_idx]-stim_samp+2000,ymin=np.min(y_seg),ymax = np.max(y_seg),color='c')
            # plt.plot(onsets[evoked_idx],y_plt[onsets[evoked_idx]]+0.01*nstims,'ro',markersize=2)

            # Get the shape of the evoked burst
            all_evoked_shape.append(y[onsets[evoked_idx] -
                                      pre_win_samp:onsets[evoked_idx] +
                                      post_win_samp] - sprt[evoked_idx])
            area = np.trapz(y[onsets[evoked_idx]:offsets[evoked_idx]])

            # Append to accumulators
            pk_time = (pks[evoked_idx] - onsets[evoked_idx]) / sr
            all_pks.append(pk_time)
            all_area.append(area)
            all_proms.append(proms[evoked_idx])
            all_widths.append(fwhm[evoked_idx] / sr)
            all_ons.append(onsets[evoked_idx] / sr)
            all_offs.append(offsets[evoked_idx] / sr)
            if evoked_idx == 0:
                # No burst precedes the evoked one in this sweep; a negative
                # index would silently wrap to the last offset.
                all_prev.append(np.nan)
            else:
                all_prev.append(
                    (onsets[evoked_idx] - offsets[evoked_idx - 1]) / sr)
            ev_sweep_ID.append(ii)
            ev_file_ID.append(f)

            # append to baseline accumulators in sweep
            pk_time_base = (pks[base_idx] - onsets[base_idx]) / sr
            base_pks.append(pk_time_base)
            base_proms.append(proms[base_idx])
            base_widths.append(fwhm[base_idx] / sr)
            base_ons.append(onsets[base_idx] / sr)
            base_offs.append(offsets[base_idx] / sr)

            # if only one peak found, do not append a nan in the IBI
            if len(pks) == 1:
                base_next.append([])
                base_prev.append([])
            else:
                base_next.append(
                    np.concatenate([
                        (onsets[base_idx[:-1] + 1] - pks[base_idx[:-1]]) / sr,
                        [np.nan]
                    ]))
                base_prev.append(
                    np.concatenate([
                        [np.nan],
                        (onsets[base_idx[1:]] - offsets[base_idx[1:] - 1]) / sr
                    ]))

            # For each spont burst, get the shapes and metadata
            for jj in base_idx:
                base_sweep_ID.append(ii)
                base_file_ID.append(f)

                if onsets[jj] - pre_win_samp < 0:
                    all_base_shapes.append(
                        np.empty((pre_win_samp + post_win_samp)) * np.nan)
                    base_area.append(np.nan)
                elif onsets[jj] + post_win_samp > len(y):
                    all_base_shapes.append(
                        np.empty((pre_win_samp + post_win_samp)) * np.nan)
                    base_area.append(np.nan)
                else:
                    all_base_shapes.append(y[onsets[jj] -
                                             pre_win_samp:onsets[jj] +
                                             post_win_samp] - sprt[jj])
                    base_area.append(np.trapz(y[onsets[jj]:offsets[jj]]))

            if onsets.shape[0] == (evoked_idx + 1):
                all_next.append((y.shape[0] - pks[evoked_idx]) / sr)
            else:
                all_next.append(
                    (onsets[evoked_idx + 1] - pks[evoked_idx]) / sr)

    # plt.savefig(os.path.join(p_fig,f'{tag}_all_traces.png'),dpi=600)
    plt.close('all')
    # reformat burst data
    evoked_shapes = np.array(all_evoked_shape).T
    base_shapes = np.array(all_base_shapes).T

    evoked_feats = pd.DataFrame()
    evoked_feats['Amplitude'] = all_proms
    evoked_feats['Peak Time'] = all_pks
    evoked_feats['Width'] = all_widths
    evoked_feats['Onset'] = all_ons
    evoked_feats['Offset'] = all_offs
    evoked_feats['Prev IBI'] = all_prev
    evoked_feats['Next IBI'] = all_next
    evoked_feats['Area'] = all_area
    evoked_feats['Recording'] = 'evoked'
    evoked_feats['file ID'] = ev_file_ID
    evoked_feats['sweep ID'] = ev_sweep_ID

    base_feats = pd.DataFrame()
    base_feats['Amplitude'] = np.concatenate(base_proms)
    base_feats['Peak Time'] = np.concatenate(base_pks)
    base_feats['Width'] = np.concatenate(base_widths)
    base_feats['Onset'] = np.concatenate(base_ons)
    base_feats['Offset'] = np.concatenate(base_offs)
    base_feats['Prev IBI'] = np.concatenate(base_prev)
    base_feats['Next IBI'] = np.concatenate(base_next)
    base_feats['Area'] = base_area
    base_feats['Recording'] = 'spontaneous'
    base_feats['file ID'] = base_file_ID
    base_feats['sweep ID'] = base_sweep_ID

    # Remove implausibly large artifacts
    base_feats = base_feats[base_feats['Amplitude'] < 20]
    evoked_feats = evoked_feats[evoked_feats['Amplitude'] < 20]

    # # Only keep bursts that get within 65% of the mean burst height at mean peak time
    # # # Excludes very slow bursts
    # m_peak_time = base_feats['Peak Time'].mean()
    # initial_rise = base_shapes[int((m_peak_time + pre_win) * sr)] / base_feats["Amplitude"].mean()
    # keep = initial_rise > 0.65
    # base_feats = base_feats.loc[keep, :]
    #
    # initial_rise = evoked_shapes[int((m_peak_time + pre_win) * sr)] / evoked_feats["Amplitude"].mean()
    # keep = initial_rise > 0.65
    # evoked_feats = evoked_feats.loc[keep, :]

    # Get rid of bursts that hit their peak in under 150 ms
    keep = base_feats['Peak Time'] > 0.150
    base_feats = base_feats.loc[keep, :]

    keep = evoked_feats['Peak Time'] > 0.150
    evoked_feats = evoked_feats.loc[keep, :]

    base_shapes = base_shapes[:, base_feats.index]
    evoked_shapes = evoked_shapes[:, evoked_feats.index]

    assert (base_feats.shape[0] == base_shapes.shape[1])
    assert (evoked_feats.shape[0] == evoked_shapes.shape[1])

    ## Remove bursts that are all nan
    keep = np.all(np.isfinite(base_shapes), axis=0)
    base_shapes = base_shapes[:, keep]
    base_feats = base_feats[keep]

    keep = np.all(np.isfinite(evoked_shapes), axis=0)
    evoked_shapes = evoked_shapes[:, keep]
    evoked_feats = evoked_feats[keep]
    # base_shapes[np.isnan(base_shapes)] = 0
    # evoked_shapes[np.isnan(evoked_shapes)] = 0

    # # Keep shapes and dataframes aligned

    # =========================== #
    # Perform PCA Decomposition #
    # =========================== #
    # First get number of PCs needed
    dec = sklearn.decomposition.PCA()
    cat_shapes = np.hstack([base_shapes, evoked_shapes])
    # Samples 500:1500 span onset to 1 s post-onset at the 1 kHz resampled
    # rate (pre_win = 0.5 s).
    dec.fit(cat_shapes[500:1500, :].T)
    # +1 because np.where returns the index at which the cumulative ratio
    # first exceeds 0.95; the number of PCs needed is that index plus one.
    n = np.where(np.cumsum(dec.explained_variance_ratio_) > 0.95)[0][0] + 1

    # Then, rerun with that number of PCs
    dec = sklearn.decomposition.PCA(n)
    cat_shapes = np.hstack([base_shapes, evoked_shapes])
    aa = cat_shapes[500:1500, :] - np.min(cat_shapes[500:1500, :], axis=0)
    dec.fit(aa.T)
    bb = evoked_shapes[500:1500, :] - np.min(evoked_shapes[500:1500, :],
                                             axis=0)
    evoked_PCA = dec.transform(bb.T)
    cc = base_shapes[500:1500, :] - np.min(base_shapes[500:1500, :], axis=0)
    base_PCA = dec.transform(cc.T)

    # Fit a robust covariance (Minimum Covariance Determinant) to the baseline
    # PCA cloud; Mahalanobis distance from this fit flags atypical bursts.
    clf = sklearn.covariance.MinCovDet()
    clf.fit(base_PCA)
    D_base = clf.mahalanobis(base_PCA)
    D_evoked = clf.mahalanobis(evoked_PCA)

    # Get the outliers
    outlier_thresh = np.median(D_base) + scipy.stats.iqr(D_base) * iqr_thresh
    # outlier_thresh = np.median(np.log(D_base)) + scipy.stats.iqr(np.log(D_base)) * iqr_thresh

    plt.figure()
    plt.hist(np.log(D_base), 20, color='k', alpha=0.2)
    plt.hist(np.log(D_evoked), 20, color='c', alpha=0.4)
    plt.axvline(np.log(outlier_thresh), color='r', linewidth=3)
    plt.xlabel('Log Distance')
    plt.savefig(os.path.join(p_fig, f'{tag}_D_dists.png'), dpi=300)
    plt.close('all')

    # ==================== #
    # Boost with intuitive feats #
    # ==================== #

    clf = sklearn.covariance.MinCovDet()
    clf.fit(base_feats[used_feats].dropna())
    D_feats_evoked = clf.mahalanobis(evoked_feats[used_feats])
    D_feats_base = clf.mahalanobis(base_feats[used_feats].dropna())
    outlier_feats_thresh = np.median(
        D_feats_base) + scipy.stats.iqr(D_feats_base) * iqr_thresh

    plt.figure()
    plt.hist(np.log(D_feats_base), 20, color='k', alpha=0.2)
    plt.hist(np.log(D_feats_evoked), 20, color='c', alpha=0.4)
    plt.axvline(np.log(outlier_feats_thresh), color='r', linewidth=3)
    plt.xlabel('Log Distance')
    plt.savefig(os.path.join(p_fig, f'{tag}_D_features_dists.png'), dpi=300)
    plt.close('all')

    # Remove classified sighs that are not big enough
    # amp_IQR = base_feats['Amplitude'].quantile(.75)-base_feats['Amplitude'].quantile(.25)
    # crit_pctil = evoked_feats['Amplitude'] > (base_feats['Amplitude'].median() + amp_IQR*amp_thresh)
    crit_amp = evoked_feats['Amplitude'] > (
        base_feats['Amplitude'].quantile(0.9))
    outlier_feats = D_feats_evoked > outlier_feats_thresh
    outlier = D_evoked > (outlier_thresh)
    # outlier = np.log(D_evoked) > (outlier_thresh)
    crit = np.logical_and(outlier_feats, outlier)
    crit = np.logical_and(crit, crit_amp).values

    # add labels to dataframe
    evoked_feats['Type'] = 'eupnea'
    evoked_feats.loc[crit, 'Type'] = 'sigh'
    evoked_feats['Distance'] = D_evoked
    evoked_feats['Feat Distance'] = D_feats_evoked
    for kk in range(n):
        evoked_feats[f'PCA{kk}'] = evoked_PCA[:, kk]

    base_feats['Distance'] = D_base
    base_feats['Feat Distance'] = np.nan
    base_feats.loc[base_feats['Next IBI'].notna(), 'Feat Distance'] = D_feats_base
    base_feats['Type'] = 'eupnea'
    for kk in range(n):
        base_feats[f'PCA{kk}'] = base_PCA[:, kk]

    # Concatenate dataframe
    df_cat = pd.concat([base_feats, evoked_feats], sort=False)
    cat_shapes = np.hstack([base_shapes, evoked_shapes])

    # Plot feature space
    plt.close('all')
    try:
        sns.pairplot(df_cat[used_feats + ['Type', 'PCA0', 'PCA1']], hue='Type')
    except KeyError:
        # Only one PC was retained, so there is no 'PCA1' column
        sns.pairplot(df_cat[used_feats + ['Type', 'PCA0']], hue='Type')
    plt.savefig(os.path.join(p_fig, f'{tag}_feature_space.png'), dpi=300)
    plt.close('all')

    # Plot traces
    plt.figure(figsize=(4, 2))
    plt.plot(cat_shapes[:, df_cat['Type'] == 'eupnea'], 'k', alpha=0.1)
    try:
        plt.plot(cat_shapes[:, df_cat['Type'] == 'sigh'], 'r', alpha=0.3)
    except Exception:
        print('No Sighs found')
    plt.tight_layout()
    plt.ylim([np.min(base_shapes), df_cat['Amplitude'].max()])
    plt.savefig(os.path.join(p_fig, f'{tag}_traces.png'), dpi=300)

    # Normalize each burst to the mean eupnea
    # TODO: Normalize to baseline eupnea??
    eup_mean = df_cat.groupby('Type')['Amplitude'].mean()['eupnea']
    df_cat["Amplitude Normed"] = df_cat["Amplitude"] / eup_mean

    # Create a shapes dataframe
    df_shapes = pd.DataFrame(cat_shapes)
    df_shapes.index /= sr
    df_shapes.index.name = 'Time'

    # TODO: write a summary datafile
    # Write data outputs

    assert (df_shapes.shape[1] == df_cat.shape[0])
    df_shapes.to_csv(os.path.join(p_results, f'{tag}_shapes.csv'))
    df_cat.to_csv(os.path.join(p_results, f'{tag}_features.csv'))
    df_summary = df_cat.groupby(['Recording', 'Type']).describe()
    df_summary.to_excel(os.path.join(p_results, f'{tag}_summary.xls'))
    plt.close('all')
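
The classification core of compute_sigh (PCA on burst shapes, then a robust Mahalanobis distance with an IQR-based outlier threshold) can be exercised on synthetic data. A minimal, self-contained sketch; the burst shapes below are fabricated purely for illustration:

import numpy as np
import scipy.stats
import sklearn.covariance
import sklearn.decomposition

rng = np.random.default_rng(0)

# Fabricated "burst shapes": columns are bursts, as in compute_sigh.
# 200 baseline bursts plus 20 evoked bursts of twice the amplitude.
t = np.linspace(0, 1, 1000)
base = np.outer(np.sin(np.pi * t), np.ones(200)) + rng.normal(0, 0.05, (1000, 200))
evoked = 2.0 * np.outer(np.sin(np.pi * t), np.ones(20)) + rng.normal(0, 0.05, (1000, 20))

# Keep enough PCs to explain 95% of the variance of all bursts together
dec = sklearn.decomposition.PCA()
dec.fit(np.hstack([base, evoked]).T)
n = np.where(np.cumsum(dec.explained_variance_ratio_) > 0.95)[0][0] + 1

dec = sklearn.decomposition.PCA(n)
dec.fit(np.hstack([base, evoked]).T)
base_pcs = dec.transform(base.T)
evoked_pcs = dec.transform(evoked.T)

# Robust covariance of the baseline cloud; evoked bursts whose Mahalanobis
# distance exceeds median + iqr_thresh * IQR of the baseline distances are
# flagged as sighs (iqr_thresh = 4, as in compute_sigh).
clf = sklearn.covariance.MinCovDet().fit(base_pcs)
d_base = clf.mahalanobis(base_pcs)
d_evoked = clf.mahalanobis(evoked_pcs)
thresh = np.median(d_base) + 4.0 * scipy.stats.iqr(d_base)
print('flagged as sighs:', int(np.sum(d_evoked > thresh)), 'of', len(d_evoked))

Example No. 4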
def compute_sigh(f_list, stim_chan, xii_chan, dia_chan, tag, condition):
    # ===================== ===================== #
    # Initialize parameters #
    # ===================== ===================== #
    pre_win = 0.2
    post_win = 0.3
    mad_thresh = 5
    lockout = 0.5  # seconds
    today = datetime.datetime.now().isoformat()[:10]
    p_fig = rf'../../figs/{today}_vivo_sweeps/{tag}'
    p_results = rf'../../results/{today}-classify_sigh_sweeps_invivo/{tag}'

    # ====================== ==================== #
    # collect bursts
    nstims = 0
    fig = plt.figure(figsize=(3, 7))
    subfig_path = os.path.join(p_fig, 'traces')
    if os.path.isdir(subfig_path):
        print(f'Already finished {tag}')
    else:
        os.makedirs(subfig_path)
    if os.path.isdir(p_results):
        print(f'Already finished {tag}')
    else:
        os.makedirs(p_results)

    all_xii_df = pd.DataFrame()
    all_dia_df = pd.DataFrame()
    all_xii_shapes = []
    all_dia_shapes = []

    time_to_evoked_burst = []
    all_burst_stim_delay = []
    all_pre_stim_IBI = []
    all_post_stim_IBI = []

    for f in f_list:
        dat = pyabf.ABF(f)
        sr = dat.dataRate
        last_time = dat.sweepX[-1] - 0.5
        print(f)
        stim_id = abf.get_channel_id_by_label(dat, stim_chan)
        xii_id = abf.get_channel_id_by_label(dat, xii_chan)
        dia_id = abf.get_channel_id_by_label(dat, dia_chan)

        for ii in range(dat.sweepCount):
            nstims += 1
            print(ii)

            # Get stimulus
            dat.setSweep(ii, stim_id)
            tvec = dat.sweepX
            stim_vals = dat.sweepY
            stim_samp = rlab_signal.binary_onsets(stim_vals, 1.)[0][0]

            # Get integrated diaphragm
            dat.setSweep(ii, dia_id)
            tvec = dat.sweepX
            dia = dat.sweepY
            dia = rlab_signal.remove_EKG(dia, sr)
            dia_int = rlab_signal.integrator(dia, sr, span=0.05, acausal=False)
            dia_int_f = dia_int - np.min(dia_int[1000:-1000])
            dia_df = burst_stats_dia(dia_int_f, sr)
            dia_df = get_sigh_idx(dia_df, thresh=mad_thresh)

            dia_shapes = burst.get_burst_shape(dia_int,
                                               dia_df['on_sec'].values,
                                               dia_df['off_sec'].values,
                                               sr,
                                               pre_win=pre_win,
                                               post_win=post_win)
            # Get xii
            dat.setSweep(ii, xii_id)
            xii_int = dat.sweepY
            xii_shapes = burst.get_burst_shape(xii_int,
                                               dia_df['on_sec'].values,
                                               dia_df['off_sec'].values,
                                               sr,
                                               pre_win=pre_win,
                                               post_win=post_win)
            xii_amp, xii_dur = shapes_to_stats(xii_shapes, sr)
            xii_df = pd.DataFrame()
            xii_df['amp'] = xii_amp
            xii_df['duration_sec'] = xii_dur

            # Get the index of the evoked burst
            burst_idx_post_stim = np.where(
                dia_df['pk_samp'].values > stim_samp)[0][0]
            burst_stim_delay = (dia_df['on_samp'].values[burst_idx_post_stim] -
                                stim_samp) / sr
            pk_time_delay = (dia_df['pk_samp'].values[burst_idx_post_stim] -
                             stim_samp) / sr
            stim_on, stim_off = rlab_signal.binary_onsets(stim_vals, 1)
            # binary_onsets returns arrays; take the first stimulus edge pair
            stim_on, stim_off = stim_on[0], stim_off[0]
            evoked_bool = np.zeros(dia_df.shape[0], dtype='bool')
            if pk_time_delay < lockout:
                evoked_bool[burst_idx_post_stim] = True
                dot_c = 'cs'
            else:
                dot_c = 'rs'

            # Find peaks and props
            plt.close('all')
            f2, ax2 = plt.subplots(nrows=2, ncols=2)
            ax2[0, 0].plot(tvec, xii_int, 'k')
            ax2[0, 0].plot(dia_df['on_sec'], xii_int[dia_df['on_samp']], 'go')
            ax2[0, 0].plot(dia_df['on_sec'][burst_idx_post_stim],
                           xii_int[dia_df['on_samp'][burst_idx_post_stim]],
                           dot_c,
                           ms=10)

            ax2[1, 0].plot(tvec, dia_int, 'k')
            ax2[1, 0].plot(dia_df['on_sec'], dia_int[dia_df['on_samp']], 'go')
            ax2[1, 0].plot(dia_df['on_sec'][burst_idx_post_stim],
                           dia_int[dia_df['on_samp'][burst_idx_post_stim]],
                           dot_c,
                           ms=10)

            ax2[0, 0].axvspan(tvec[stim_on],
                              tvec[stim_off],
                              color='c',
                              alpha=0.5)
            ax2[1, 0].axvspan(tvec[stim_on],
                              tvec[stim_off],
                              color='c',
                              alpha=0.5)

            ax2[0, 0].set_ylabel('Hypoglossal')
            ax2[1, 0].set_ylabel('Diaphragm')

            ax2[0, 1].plot(tvec, xii_int, 'k')
            ax2[0, 1].plot(dia_df['on_sec'], xii_int[dia_df['on_samp']], 'go')
            ax2[0, 1].plot(dia_df['on_sec'][burst_idx_post_stim],
                           xii_int[dia_df['on_samp'][burst_idx_post_stim]],
                           dot_c,
                           ms=10)

            ax2[1, 1].plot(tvec, dia_int, 'k')
            ax2[1, 1].plot(dia_df['on_sec'], dia_int[dia_df['on_samp']], 'go')
            ax2[1, 1].plot(dia_df['on_sec'][burst_idx_post_stim],
                           dia_int[dia_df['on_samp'][burst_idx_post_stim]],
                           dot_c,
                           ms=10)

            ax2[0, 1].axvspan(tvec[stim_on],
                              tvec[stim_off],
                              color='c',
                              alpha=0.5)
            ax2[1, 1].axvspan(tvec[stim_on],
                              tvec[stim_off],
                              color='c',
                              alpha=0.5)
            ax2[0, 1].set_xlim(tvec[stim_on] - 1, tvec[stim_off] + 1)
            ax2[1, 1].set_xlim(tvec[stim_on] - 1, tvec[stim_off] + 1)

            ax2[1, 0].set_xlabel('Time (s)')
            ax2[1, 1].set_xlabel('Time (s)')
            sns.despine()
            plt.tight_layout()
            plt.savefig(os.path.join(subfig_path, f'{tag}_{nstims-1}_y.png'),
                        dpi=150)
            plt.close('all')

            xii_df['evoked'] = evoked_bool
            dia_df['evoked'] = evoked_bool

            xii_df['sweep'] = ii
            dia_df['sweep'] = ii

            xii_df['file'] = f
            dia_df['file'] = f

            xii_df['stim'] = nstims
            dia_df['stim'] = nstims

            time_to_evoked_burst.append(
                (dia_df['on_samp'].values[evoked_bool] - stim_samp) / sr)

            all_xii_df = pd.concat([all_xii_df, xii_df])
            all_dia_df = pd.concat([all_dia_df, dia_df])

            all_xii_shapes.append(xii_shapes)
            all_dia_shapes.append(dia_shapes)

            # append information about post_stim behavior
            all_burst_stim_delay.append(burst_stim_delay)
            if burst_idx_post_stim > 0:
                all_pre_stim_IBI.append(dia_df['postBI'][burst_idx_post_stim - 1])
            else:
                # No burst precedes the stimulus in this sweep
                all_pre_stim_IBI.append(np.nan)
            all_post_stim_IBI.append(dia_df['postBI'][burst_idx_post_stim])

    temp_time2evoked = pd.DataFrame()
    temp_time2evoked['evoked_delay'] = all_burst_stim_delay
    temp_time2evoked['pre_IBI'] = all_pre_stim_IBI
    temp_time2evoked['post_IBI'] = all_post_stim_IBI
    temp_time2evoked['condition'] = condition
    temp_time2evoked.to_csv(os.path.join(p_results, f'evoked_delay_{tag}.csv'))
    plt.close('all')
    # reformat burst data
    xii_shapes = np.hstack(all_xii_shapes)
    dia_shapes = np.hstack(all_dia_shapes)

    feats = pd.DataFrame()
    feats['dia_amp'] = all_dia_df['amp']
    feats['dia_dur'] = all_dia_df['duration_sec']
    feats['dia_auc'] = all_dia_df['auc']
    feats['xii_amp'] = all_xii_df['amp']
    feats['xii_dur'] = all_xii_df['duration_sec']
    feats['peak_time'] = all_dia_df['pk_time'] - all_dia_df['on_sec']
    feats['post_IBI'] = all_dia_df['postBI']
    feats['evoked'] = all_dia_df['evoked']
    feats['sweep_no'] = all_dia_df['stim']
    feats['absolute_peak_time'] = all_dia_df['pk_time']
    feats['Type'] = all_dia_df['type']
    feats = feats.reset_index(drop=True)

    # # Remove gigantic crap
    # keep = feats['dia_amp']<30
    # feats = feats.loc[keep]
    # xii_shapes = xii_shapes[:,keep]
    # dia_shapes = dia_shapes[:,keep]
    # feats = feats.reset_index(drop=True)

    # Get rid of bursts that hit their peak in under 10ms
    keep = feats['peak_time'] > 0.010
    feats = feats.loc[keep, :]
    xii_shapes = xii_shapes[:, keep]
    dia_shapes = dia_shapes[:, keep]
    feats = feats.reset_index(drop=True)

    ## Remove bursts that are all nan
    keep1 = np.all(np.isfinite(xii_shapes), axis=0)
    keep2 = np.all(np.isfinite(dia_shapes), axis=0)
    keep = np.logical_and(keep1, keep2)

    xii_shapes = xii_shapes[:, keep]
    dia_shapes = dia_shapes[:, keep]
    feats = feats[keep]
    feats = feats.reset_index(drop=True)

    # .copy() so the PCA columns added below do not trigger chained-assignment
    # warnings on a view
    base_feats = feats[~feats.evoked].copy()
    evoked_feats = feats[feats.evoked].copy()

    dia_shapes_baseline = dia_shapes[:, feats.evoked == False]
    xii_shapes_baseline = xii_shapes[:, feats.evoked == False]

    dia_shapes_evoked = dia_shapes[:, feats.evoked]
    xii_shapes_evoked = xii_shapes[:, feats.evoked]

    # =========================== #
    # Perform PCA Decomposition on diaphragm#
    # =========================== #
    # # First get number of PCs needed
    # dec = sklearn.decomposition.PCA()
    # dec.fit(dia_shapes)
    # n = np.where(np.cumsum(dec.explained_variance_ratio_) > 0.95)[0][0]
    # if n == 1:
    #     n=2
    n = 3

    # Then, rerun with that number of PCs
    dec = sklearn.decomposition.PCA(n)
    aa = dia_shapes - np.min(dia_shapes, axis=0)
    dec.fit(aa.T)
    bb = dia_shapes_evoked - np.min(dia_shapes_evoked, axis=0)
    evoked_PCA_dia = dec.transform(bb.T)[:, :n]
    cc = dia_shapes_baseline - np.min(dia_shapes_baseline, axis=0)
    base_PCA_dia = dec.transform(cc.T)[:, :n]

    # =========================== #
    # Perform PCA Decomposition on prebot#
    # =========================== #
    # # First get number of PCs needed
    # dec = sklearn.decomposition.PCA()
    # dec.fit(xii_shapes)
    #
    # n = np.where(np.cumsum(dec.explained_variance_ratio_) > 0.95)[0][0]
    # if n == 1:
    #     n=2
    n = 3

    # Then, rerun with that number of PCs
    dec = sklearn.decomposition.PCA(n)
    aa = xii_shapes - np.min(xii_shapes, axis=0)
    dec.fit(aa.T)
    bb = xii_shapes_evoked - np.min(xii_shapes_evoked, axis=0)
    evoked_PCA_xii = dec.transform(bb.T)[:, :n]
    cc = xii_shapes_baseline - np.min(xii_shapes_baseline, axis=0)
    base_PCA_xii = dec.transform(cc.T)[:, :n]

    # combine PCAs
    evoked_PCA = np.hstack([evoked_PCA_dia, evoked_PCA_xii])
    base_PCA = np.hstack([base_PCA_dia, base_PCA_xii])

    for kk in range(evoked_PCA.shape[1]):
        evoked_feats[f'PCA{kk}'] = evoked_PCA[:, kk]

    for kk in range(base_PCA.shape[1]):
        base_feats[f'PCA{kk}'] = base_PCA[:, kk]

    # Concatenate dataframe
    df_cat = pd.concat([base_feats, evoked_feats], sort=False)

    xii_eup_mean = df_cat.groupby('Type')['xii_amp'].mean()['eupnea']
    dia_eup_mean = df_cat.groupby('Type')['dia_amp'].mean()['eupnea']

    df_cat["xii_amp_normed"] = df_cat["xii_amp"] / xii_eup_mean
    df_cat["dia_amp_normed"] = df_cat["dia_amp"] / dia_eup_mean

    # Plot feature space
    used_feats = ['dia_amp', 'xii_amp', 'dia_auc']
    plt.close('all')
    plt.figure(figsize=(10, 10))
    sns.pairplot(df_cat[used_feats + ['Type', 'PCA0', f'PCA{n}']], hue='Type')
    plt.savefig(os.path.join(p_fig, f'{tag}_feature_space.png'), dpi=300)
    plt.close('all')

    # Plot traces
    fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(2, 4))
    try:
        ax[0].plot(xii_shapes_evoked[:, evoked_feats['Type'] == 'eupnea'],
                   'c',
                   alpha=0.1)
    except Exception:
        pass
    ax[0].plot(np.nanmedian(xii_shapes_baseline, axis=1), 'k', lw=2)
    ax[0].plot(xii_shapes_baseline, 'k', alpha=0.05, lw=0.5)
    try:
        ax[1].plot(dia_shapes_evoked[:, evoked_feats['Type'] == 'eupnea'],
                   'c',
                   alpha=0.1)
    except Exception:
        pass
    ax[1].plot(np.nanmedian(dia_shapes_baseline, axis=1), 'k', lw=2)
    ax[1].plot(dia_shapes_baseline, 'k', alpha=0.05, lw=0.5)
    try:
        ax[0].plot(xii_shapes_evoked[:, evoked_feats['Type'] == 'sigh'],
                   'r',
                   alpha=0.3)
        ax[1].plot(dia_shapes_evoked[:, evoked_feats['Type'] == 'sigh'],
                   'r',
                   alpha=0.3)
    except Exception:
        print('No Sighs found')
    plt.tight_layout()
    # ax[0].set_ylim([np.min(xii_shapes_baseline),df_cat['xii_amp'].max()])
    # ax[1].set_ylim([np.min(dia_shapes_baseline),df_cat['dia_amp'].max()])
    plt.savefig(os.path.join(p_fig, f'{tag}_traces.png'), dpi=300)

    plt.close('all')
    plot_stim_aligned(f_list, stim_chan, xii_chan, dia_chan, pre=0.2, post=1)
    plt.savefig(os.path.join(subfig_path, f'{tag}_stim_aligned.png'), dpi=300)
    plt.close('all')

    # Create a shapes dataframe
    df_dia_shapes = pd.DataFrame(dia_shapes)
    df_dia_shapes.index /= sr
    df_dia_shapes.index.name = 'Time'

    df_xii_shapes = pd.DataFrame(xii_shapes)
    df_xii_shapes.index /= sr
    df_xii_shapes.index.name = 'Time'

    # Write data outputs
    print('Writing Data outputs')
    df_xii_shapes.to_csv(os.path.join(p_results, f'{tag}_xii_shapes.csv'))
    df_dia_shapes.to_csv(os.path.join(p_results, f'{tag}_dia_shapes.csv'))

    df_cat['condition'] = condition
    # df_cat = df_cat.sort_values(['sweep_no','absolute_peak_time'])
    df_cat = df_cat.set_index(['sweep_no', 'absolute_peak_time']).sort_index()
    df_cat.to_excel(os.path.join(p_results, f'{tag}_features.xls'))
    df_summary = df_cat.groupby(['evoked', 'Type']).describe()
    df_summary.to_excel(os.path.join(p_results, f'{tag}_summary.xls'))
    plt.close('all')
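
A hedged usage sketch for the sweep-based variant above; the file names, channel labels, and condition tag are hypothetical:

# Hypothetical call; channel labels must match the ABF protocol.
compute_sigh(f_list=['mouse01_stim.abf'],
             stim_chan='Stim',
             xii_chan='XII',
             dia_chan='Dia',
             tag='mouse01',
             condition='control')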
Example No. 5
def compute_sigh(fn, xii_chan, dia_chan, tag, start, stop, condition):
    # ===================== ===================== #
    # Initialize parameters #
    # ===================== ===================== #
    pre_win = 0.2
    post_win = 0.3

    today = datetime.datetime.now().isoformat()[:10]
    p_fig = rf'../../figs/{today}-invivo-baseline/{tag}'
    p_results = rf'../../results/{today}classify_sigh_sweeps_invivo_baseline/{tag}'

    # ====================== ==================== #
    # collect bursts
    nstims = 0
    subfig_path = os.path.join(p_fig, 'traces')
    if os.path.isdir(subfig_path):
        print(f'Already finished {tag}')
    else:
        os.makedirs(subfig_path)
    if os.path.isdir(p_results):
        print(f'Already finished {tag}')
    else:
        os.makedirs(p_results)

    all_xii_df = pd.DataFrame()
    all_dia_df = pd.DataFrame()

    # ===================== #
    # Work on file #
    # ===================== #
    dat = pyabf.ABF(fn)
    sr = dat.dataRate
    # Cast to int: float sample indices would fail when slicing below
    start_samp = int(start * sr)
    stop_samp = int(stop * sr)
    print(fn)
    xii_id = abf.get_channel_id_by_label(dat, xii_chan)
    dia_id = abf.get_channel_id_by_label(dat, dia_chan)

    # Get integrated diaphragm
    dat.setSweep(0, dia_id)
    tvec = dat.sweepX[start_samp:stop_samp]
    dia = dat.sweepY[start_samp:stop_samp]
    dia = rlab_signal.remove_EKG(dia, sr)
    dia_int = rlab_signal.integrator(dia, sr, span=0.05, acausal=False)
    dia_int_f = dia_int - np.min(dia_int[1000:-1000])
    dia_df = burst_stats_dia(dia_int_f, sr)
    dia_df = get_sigh_idx(dia_df, thresh=6)

    dia_shapes = burst.get_burst_shape(dia_int,
                                       dia_df['on_sec'].values,
                                       dia_df['off_sec'].values,
                                       sr,
                                       pre_win=pre_win,
                                       post_win=post_win)

    # Get prebot
    dat.setSweep(0, xii_id)
    xii_int = dat.sweepY[start_samp:stop_samp]
    xii_shapes = burst.get_burst_shape(xii_int,
                                       dia_df['on_sec'].values,
                                       dia_df['off_sec'].values,
                                       sr,
                                       pre_win=pre_win,
                                       post_win=post_win)
    xii_amp, xii_dur = shapes_to_stats(xii_shapes, sr)
    xii_df = pd.DataFrame()
    xii_df['amp'] = xii_amp
    xii_df['duration_sec'] = xii_dur

    # =================== #
    # Visualize peaks
    # =================== #

    viz_peaks(tvec, xii_int, dia_int, dia_df)
    plt.savefig(os.path.join(subfig_path, f'{tag}_baseline_y.png'), dpi=300)
    plt.close('all')

    viz_peaks(tvec, xii_int, dia_int, dia_df)
    plt.xlim(start + 10, start + 30)
    plt.savefig(os.path.join(subfig_path, f'{tag}_baseline_early.png'),
                dpi=300)
    plt.close('all')

    viz_peaks(tvec, xii_int, dia_int, dia_df)
    plt.xlim(stop - 30, stop - 10)
    plt.savefig(os.path.join(subfig_path, f'{tag}_baseline_late.png'), dpi=300)
    plt.close('all')

    xii_df['file'] = fn
    dia_df['file'] = fn

    xii_df['stim'] = nstims
    dia_df['stim'] = nstims

    feats = pd.DataFrame()
    feats['dia_amp'] = dia_df['amp']
    feats['dia_dur'] = dia_df['duration_sec']
    feats['dia_auc'] = dia_df['auc']
    feats['xii_amp'] = xii_df['amp']
    feats['xii_dur'] = xii_df['duration_sec']
    feats['peak_time'] = dia_df['pk_time'] - dia_df['on_sec']
    feats['peak_time_global'] = dia_df['pk_time']
    feats['post_IBI'] = dia_df['postBI']
    feats['Type'] = dia_df['type']
    feats = feats.reset_index(drop=True)

    # ============================= #
    # REMOVE BAD STUFF #
    # ============================= #

    # Remove implausibly large artifacts (disabled)
    # keep = feats['dia_amp']<0.5
    # feats = feats[keep]
    # pbt_shapes = pbt_shapes[:,keep]
    # dia_shapes = dia_shapes[:,keep]
    # feats = feats.reset_index(drop=True)
    #
    # keep = feats['pbt_amp']<4
    # feats = feats[keep]
    # pbt_shapes = pbt_shapes[:,keep]
    # dia_shapes = dia_shapes[:,keep]
    # feats = feats.reset_index(drop=True)

    # Get rid of bursts that hit their peak in under 10ms
    # keep = feats['peak_time'] > 0.010
    # feats = feats[keep]
    # xii_shapes = xii_shapes[:,keep]
    # dia_shapes = dia_shapes[:,keep]
    # feats = feats.reset_index(drop=True)

    ## Remove bursts that are all nan
    keep1 = np.all(np.isfinite(xii_shapes), axis=0)
    keep2 = np.all(np.isfinite(dia_shapes), axis=0)
    keep = np.logical_and(keep1, keep2)

    xii_shapes = xii_shapes[:, keep]
    dia_shapes = dia_shapes[:, keep]
    feats = feats[keep]
    feats = feats.reset_index(drop=True)

    keep = np.all(feats.notna(), axis=1).values
    xii_shapes = xii_shapes[:, keep]
    dia_shapes = dia_shapes[:, keep]
    feats = feats[keep].reset_index(drop=True)

    assert (feats.shape[0] == xii_shapes.shape[1])
    assert (feats.shape[0] == dia_shapes.shape[1])

    # =========================== #
    # Perform PCA Decomposition on diaphragm#
    # =========================== #
    n = 3
    dec = sklearn.decomposition.PCA(n)
    aa = dia_shapes - np.min(dia_shapes, axis=0)
    dec.fit(aa.T)
    bb = dia_shapes - np.min(dia_shapes, axis=0)
    PCA_dia = dec.transform(bb.T)[:, :n]

    # =========================== #
    # Perform PCA Decomposition on prebot#
    # =========================== #
    n = 3

    # Then, rerun with that number of PCs
    dec = sklearn.decomposition.PCA(n)
    aa = xii_shapes - np.min(xii_shapes, axis=0)
    dec.fit(aa.T)
    bb = xii_shapes - np.min(xii_shapes, axis=0)
    PCA_pbt = dec.transform(bb.T)[:, :n]

    # combine PCAs
    PCA = np.hstack([PCA_dia, PCA_pbt])

    # add labels to dataframe
    for kk in range(PCA.shape[1]):
        feats[f'PCA{kk}'] = PCA[:, kk]
    feats['condition'] = condition

    xii_eup_mean = feats.groupby('Type')['xii_amp'].mean()['eupnea']
    dia_eup_mean = feats.groupby('Type')['dia_amp'].mean()['eupnea']

    feats["xii_amp_normed"] = feats["xii_amp"] / xii_eup_mean
    feats["dia_amp_normed"] = feats["dia_amp"] / dia_eup_mean

    # ================= #
    # PLOTS #
    # ================= #

    # Plot feature space
    plt.close('all')
    plt.figure(figsize=(10, 10))

    used_feats = ['dia_amp', 'xii_amp', 'dia_auc']
    sns.pairplot(feats[used_feats + ['Type', 'PCA0', f'PCA{n}']], hue='Type')
    plt.savefig(os.path.join(p_fig, f'{tag}_feature_space.png'), dpi=300)
    plt.close('all')

    # Plot traces
    fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(2, 3))
    ax[0].plot(xii_shapes[:, feats['Type'] == 'eupnea'], 'k', alpha=0.1)
    ax[1].plot(dia_shapes[:, feats['Type'] == 'eupnea'], 'k', alpha=0.1)
    try:
        ax[0].plot(xii_shapes[:, feats['Type'] == 'sigh'], 'r', alpha=0.3)
        ax[1].plot(dia_shapes[:, feats['Type'] == 'sigh'], 'r', alpha=0.3)
    except Exception:
        print('No Sighs found')
    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(p_fig, f'{tag}_traces.png'), dpi=300)

    # ================= #
    # OUTPUT DATA #
    # ================= #

    # Create a shapes dataframe
    df_dia_shapes = pd.DataFrame(dia_shapes)
    df_dia_shapes.index /= sr
    df_dia_shapes.index.name = 'Time'

    df_xii_shapes = pd.DataFrame(xii_shapes)
    df_xii_shapes.index /= sr
    df_xii_shapes.index.name = 'Time'

    # Write data outputs
    print('Writing Data outputs')
    df_xii_shapes.to_csv(os.path.join(p_results, f'{tag}_xii_shapes.csv'))
    df_dia_shapes.to_csv(os.path.join(p_results, f'{tag}_dia_shapes.csv'))

    feats.to_csv(os.path.join(p_results, f'{tag}_features.csv'))
    df_summary = feats.groupby(['Type']).describe()
    df_summary.to_excel(os.path.join(p_results, f'{tag}_summary.xls'))
    df_sigh_interval = feats[feats.Type == 'sigh']['peak_time_global']
    df_sigh_interval.to_csv(os.path.join(p_results, f'{tag}_sigh_times.csv'))
    plt.close('all')
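
A hedged usage sketch for the baseline variant; the file name, channel labels, and analysis window are hypothetical:

# Hypothetical call; start and stop delimit the analysis window in seconds.
compute_sigh(fn='mouse01_baseline.abf',
             xii_chan='XII',
             dia_chan='Dia',
             tag='mouse01_baseline',
             start=60,
             stop=660,
             condition='control')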