def get_trial_info(subject, inputs, outputs, recompute, logpath):
    """Build a per-trial timing table from Presentation logs and BrainVision markers.

    Columns of the tab-separated output (``outputs["save_path"]``):

    1. threat: trial type (high/low threat), taken from the log's "shock" column
    2. baseline: baseline onset in seconds (one second before the cue)
    3. cue: cue onset in seconds ("Stimulus/S  2" markers)
    4. stimulus: stimulus onset in seconds (cue + CSI / 1000)

    Trials whose cue-to-stimulus interval is shorter than 6 seconds are
    dropped. Nothing happens when the output already exists and ``recompute``
    is falsy. ``logpath`` is accepted for interface compatibility but unused.
    """
    out_file = Path(individualize_path(outputs["save_path"], subject, expand_name=True))
    if out_file.exists() and not recompute:    # only recompute if requested
        print(f"Not re-computing {out_file}")
        return

    def _first_match(key):
        # First file matching the subject-specific glob pattern stored under `key`.
        return next(Path(".").glob(individualize_path(inputs[key], subject)))

    presentation_file = _first_match("log_path")
    marker_file = _first_match("marker_path")

    log_table = pd.read_csv(presentation_file, sep="\t", usecols=["CSI", "shock"])
    annotations = read_annotations(marker_file, ECG_SFREQ_ORIGINAL)
    is_cue = annotations.description == "Stimulus/S  2"
    cue_onsets = annotations.onset[is_cue]
    assert len(cue_onsets) == log_table.shape[0], (
        "Unequal number of trials between BrainVision and Presentation files"
        f" for participant {subject}.")

    # CSI is converted from milliseconds to seconds by dividing by 1000.
    stimulus_onsets = cue_onsets + log_table["CSI"] / 1000
    trials = pd.DataFrame({
        "threat": log_table["shock"],
        "baseline": cue_onsets - 1,  # baseline window opens 1 s before the cue
        "cue": cue_onsets,
        "stimulus": stimulus_onsets,
    })

    # Keep only trials with an anticipation window of at least 6 seconds.
    long_enough = (trials["stimulus"] - trials["cue"]) >= 6
    trials = trials[long_enough]
    trials.to_csv(out_file, sep="\t", header=True, index=False, float_format="%.4f")
Пример #2
0
def get_period_ecg(subject, inputs, outputs, recompute, logpath):
    """Compute continuous heart period from R-peak sample indices.

    1. Compute inter-beat intervals in milliseconds from the R-peaks read
       from ``inputs["physio_path"]`` (sample indices at ECG_SFREQ_DECIMATED).
    2. Interpolate the intervals to a time series sampled at ECG_PERIOD_SFREQ
       Hz, up to the last R-peak, and write it to ``outputs["save_path"]``.

    Nothing happens when the output already exists and ``recompute`` is falsy.
    If ``logpath`` is truthy, a diagnostic figure is saved there.

    Raises:
        ValueError: if fewer than two R-peaks are available.
    """
    save_path = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if save_path.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {save_path}")
        return
    physio_path = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    peaks = np.ravel(pd.read_csv(physio_path, sep="\t", header=None))
    if peaks.size < 2:
        raise ValueError(f"Need at least two R-peaks to compute heart period,"
                         f" found {peaks.size} in {physio_path}.")

    # Compute period in milliseconds. Prepending 0 gives period one element
    # per peak; that placeholder is then replaced by the first real interval
    # so the series starts with a realistic value.
    period = np.ediff1d(peaks, to_begin=0) / ECG_SFREQ_DECIMATED * 1000
    period[0] = period[1]

    # Interpolate instantaneous heart period at ECG_PERIOD_SFREQ Hz. Interpolate
    # up until the last R-peak.
    duration = peaks[-1] / ECG_SFREQ_DECIMATED  # in seconds
    nsamples = int(np.rint(duration * ECG_PERIOD_SFREQ))
    period_interpolated = interpolate_signal(peaks, period, nsamples)

    pd.Series(period_interpolated).to_csv(save_path,
                                          sep="\t",
                                          header=False,
                                          index=False,
                                          float_format="%.6f")

    if not logpath:
        return

    fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True)
    # Peaks are sample indices at ECG_SFREQ_DECIMATED, so dividing yields their
    # exact times in seconds (including the final peak).
    ax.vlines(peaks / ECG_SFREQ_DECIMATED,
              ymin=period.min(),
              ymax=period.max(),
              label="R-peaks",
              alpha=.3,
              colors="r")
    sec = np.linspace(0, duration, nsamples)
    ax.plot(sec,
            period_interpolated,
            label=("period interpolated between R-peaks at"
                   f" {ECG_PERIOD_SFREQ}Hz"))
    ax.set_xlabel("seconds")
    ax.legend(loc="upper right")

    fig.savefig(logpath, dpi=200)
    plt.close(fig)
Пример #3
0
def preprocess_ecg(subject, inputs, outputs, recompute, logpath):
    """Downsample and flip the raw ECG stored in a BrainVision recording.

    Steps:
    1. decimate from the original rate (2500Hz) to ECG_SFREQ_DECIMATED (500Hz)
    2. flip the inverted signal with ``invert_signal``

    The processed signal is written to ``outputs["save_path"]`` as one value
    per line. Nothing happens when the output already exists and ``recompute``
    is falsy. If ``logpath`` is truthy, a diagnostic figure is saved there.
    """
    target = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if target.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {target}")
        return
    source = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    recording = mne.io.read_raw_brainvision(source, preload=False, verbose="error")
    raw_ecg = recording.get_data(picks=ECG_CHANNELS).ravel()
    fs = recording.info["sfreq"]
    assert fs == ECG_SFREQ_ORIGINAL, (
        f"Sampling frequency {fs} doesn't match expected sampling frequency"
        f" {ECG_SFREQ_ORIGINAL}.")

    # Integer decimation factor from the original rate down to ECG_SFREQ_DECIMATED.
    factor = int(np.floor(fs / ECG_SFREQ_DECIMATED))
    downsampled = decimate_signal(raw_ecg, factor)
    flipped = invert_signal(downsampled)

    pd.Series(flipped).to_csv(target,
                              sep="\t",
                              header=False,
                              index=False,
                              float_format="%.4f")

    if not logpath:
        return

    fig, (ax_raw, ax_proc) = plt.subplots(nrows=2, ncols=1, sharex=True)

    n_raw = len(raw_ecg)
    t_raw = np.linspace(0, n_raw / fs, n_raw)
    ax_raw.plot(t_raw, raw_ecg, label=f"original ({fs}Hz)")
    ax_raw.set_xlabel("seconds")
    ax_raw.legend(loc="upper right")

    n_proc = len(downsampled)
    t_proc = np.linspace(0, n_proc / ECG_SFREQ_DECIMATED, n_proc)
    ax_proc.plot(t_proc, downsampled, label=f"downsampled ({ECG_SFREQ_DECIMATED}Hz)")
    ax_proc.plot(t_proc, flipped, label=f"flipped ({ECG_SFREQ_DECIMATED}Hz)")
    ax_proc.set_xlabel("seconds")
    ax_proc.legend(loc="upper right")

    fig.savefig(logpath, dpi=200)
    plt.close(fig)
Пример #4
0
def get_peaks_ecg(subject, inputs, outputs, recompute, logpath):
    """Locate R-peaks in the preprocessed ECG and save their sample indices.

    1. detect R-peaks with ``ecg_peaks``
    2. correct detection artifacts with ``correct_peaks`` (iterative mode)

    The ECG is read from ``inputs["physio_path"]`` and is processed at
    ECG_SFREQ_DECIMATED. Nothing happens when the output already exists and
    ``recompute`` is falsy. If ``logpath`` is truthy, a diagnostic figure is
    saved there.
    """
    out_file = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if out_file.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {out_file}")
        return
    in_file = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    signal = np.ravel(pd.read_csv(in_file, sep="\t", header=None))
    raw_peaks = ecg_peaks(signal, ECG_SFREQ_DECIMATED)
    clean_peaks = correct_peaks(raw_peaks, ECG_SFREQ_DECIMATED, iterative=True)

    # Peaks are stored as sample indices, one per line.
    pd.Series(clean_peaks).to_csv(out_file,
                                  sep="\t",
                                  header=False,
                                  index=False,
                                  float_format="%.4f")

    if not logpath:
        return

    fig, ax = plt.subplots(nrows=1, ncols=1)
    seconds = np.linspace(0, len(signal) / ECG_SFREQ_DECIMATED, len(signal))
    ax.plot(seconds, signal)

    # (peak indices, z-order, color, marker, legend label)
    overlays = (
        (raw_peaks, 3, "r", "+", "uncorrected R-peaks"),
        (clean_peaks, 4, "g", "x", "corrected R-peaks"),
    )
    for idx, layer, color, marker, label in overlays:
        ax.scatter(seconds[idx],
                   signal[idx],
                   zorder=layer,
                   c=color,
                   marker=marker,
                   s=300,
                   label=label)
    ax.set_xlabel("seconds")
    ax.legend(loc="upper right")

    fig.savefig(logpath, dpi=200)
    plt.close(fig)
Пример #5
0
def remove_outliers_period_ecg(subject, inputs, outputs, recompute, logpath):
    """Remove outliers from a heart period series.

    Two kinds of samples are flagged and written as NaN to
    ``outputs["save_path"]``:

    - absolute outliers: periods outside [60000 / HR_MAX, 60000 / HR_MIN]
      (milliseconds, assuming HR_MIN/HR_MAX are in beats per minute);
    - relative outliers: periods deviating from a running-median trend by more
      than MAD_THRESHOLD_MULTIPLIER times the median absolute deviation. Both
      the trend and the deviation are computed with absolute outliers replaced
      by the series median.

    Nothing happens when the output already exists and ``recompute`` is falsy.
    If ``logpath`` is truthy, a diagnostic figure is saved there.
    """
    save_path = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if save_path.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {save_path}")
        return
    physio_path = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    period = np.ravel(pd.read_csv(physio_path, sep="\t", header=None))

    # Remove outliers based on absolute cutoffs. Those cutoffs have been chosen
    # based on the visual inspection of all heart period time series data. The
    # cutoffs have been set such that they preserve the data as much as possible
    # (when in doubt don't flag a period as outlier).
    min_period = 60000 / HR_MAX
    max_period = 60000 / HR_MIN
    abs_outliers = (period < min_period) | (period > max_period)

    # Median filter period with absolute outliers removed.
    period_without_abs_outliers = period.copy()
    period_without_abs_outliers[abs_outliers] = np.median(period)
    kernel = int(np.rint(ECG_PERIOD_SFREQ * RUNNING_MEDIAN_KERNEL_SIZE))
    if kernel % 2 == 0:
        kernel += 1  # odd kernel keeps the running median centered on each sample
    period_trend = median_filter(period_without_abs_outliers, size=kernel)

    # Flag outliers based on relative cutoffs around the trend.
    rel_threshold = MAD_THRESHOLD_MULTIPLIER * median_absolute_deviation(
        period_without_abs_outliers)
    upper_rel_threshold = period_trend + rel_threshold
    lower_rel_threshold = period_trend - rel_threshold
    rel_outliers = (period < lower_rel_threshold) | (period > upper_rel_threshold)

    # Absolute outliers are removed from the output as well, not only from the
    # trend estimate, so values outside the physiological range never survive.
    outliers = abs_outliers | rel_outliers
    period_masked_outliers = ma.masked_where(outliers, period)
    assert period_masked_outliers.size == period.size

    pd.Series(ma.filled(period_masked_outliers,
                        fill_value=np.nan)).to_csv(save_path,
                                                   sep="\t",
                                                   header=False,
                                                   index=False,
                                                   float_format="%.6f",
                                                   na_rep="NaN")

    if not logpath:
        return

    fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=True)
    sec = np.linspace(0, period.size / ECG_PERIOD_SFREQ, period.size)
    ax0.plot(sec, period)
    ax0.fill_between(sec,
                     upper_rel_threshold,
                     lower_rel_threshold,
                     color="lime",
                     alpha=.5)
    ax0.plot(sec, period_trend, c="lime")
    # Index with the explicit boolean array: a masked array's ``.mask`` is the
    # scalar ``nomask`` when nothing is flagged.
    ax1.vlines(sec[outliers],
               period_masked_outliers.min(),
               period_masked_outliers.max(),
               colors="fuchsia")
    ax1.plot(sec, period_masked_outliers)
    ax1.set_xlabel("seconds")

    fig.savefig(logpath, dpi=200)
    plt.close(fig)
Пример #6
0
def get_sway_bb(subject, inputs, outputs, recompute, logpath):
    """Compute body sway metrics from the filtered center of pressure (COP).

    Reads the COP table (columns "ap_filt" and "ml_filt") from
    ``inputs["physio_path"]`` and writes one row per sample with:

    1. ap_sway: rolling standard deviation of anterior-posterior (back-forth) COP
    2. ml_sway: rolling standard deviation of medio-lateral (left-right) COP
    3. radius: rolling mean of the COP's radial displacement (``cop_radius`` of
       the AP/ML COP demeaned by its rolling mean)
    4. path: Euclidean COP path length between consecutive samples (0 for the
       first sample)

    All rolling windows are centered and BB_MOVING_WINDOW seconds wide. Nothing
    happens when the output already exists and ``recompute`` is falsy. If
    ``logpath`` is truthy, four diagnostic figures are saved next to it.
    """
    save_path = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if save_path.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {save_path}")
        return
    physio_path = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    cop = pd.read_csv(physio_path, sep="\t", header=0)

    # Width of the rolling window in samples.
    n_samples = int(np.rint(BB_MOVING_WINDOW * BB_SFREQ_DECIMATED))

    def _rolling(data):
        # Centered window; min_periods=1 keeps the edges defined.
        return data.rolling(window=n_samples, min_periods=1, center=True)

    # Compute body sway.
    ap_sway = _rolling(cop.loc[:, "ap_filt"]).std()
    ml_sway = _rolling(cop.loc[:, "ml_filt"]).std()

    # Compute moving average of the center-of-pressure's radial displacement.
    cop_demeaned = cop - _rolling(cop).mean()
    radius = cop_demeaned.loc[:, "ap_filt"].combine(
        cop_demeaned.loc[:, "ml_filt"], cop_radius)
    radius_avg = _rolling(radius).mean()

    # Compute sway path.
    ap_path = np.ediff1d(cop.loc[:, "ap_filt"], to_begin=0)**2
    ml_path = np.ediff1d(cop.loc[:, "ml_filt"], to_begin=0)**2
    total_path = np.sqrt(ap_path + ml_path)

    pd.DataFrame({
        "ap_sway": ap_sway,
        "ml_sway": ml_sway,
        "radius": radius_avg,
        "path": total_path
    }).to_csv(save_path,
              sep="\t",
              header=True,
              index=False,
              float_format="%.4f")  # NaNs are saved as empty strings

    if not logpath:
        return

    sec = cop.index / BB_SFREQ_DECIMATED
    fig0, ax = plt.subplots()
    ax.set_title(f"moving window of {BB_MOVING_WINDOW} seconds")
    ax.set_xlabel("seconds")
    ax.set_ylabel("sway (mm)")
    ax.plot(sec, ap_sway, label="anterior-posterior sway")
    ax.plot(sec, ml_sway, label="medio-lateral sway")
    ax.legend(loc="upper right")

    fig1, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=True)
    ax0.set_title("COP")
    ax1.set_title(
        f"COP demeaned with moving window of {BB_MOVING_WINDOW} seconds")
    cop.set_index(sec).plot(ax=ax0)
    # The demeaned COP goes on its own (lower) axis, matching its title.
    cop_demeaned.set_index(sec).plot(ax=ax1)
    ax1.set_xlabel("seconds")

    fig2, ax = plt.subplots()
    ax.plot(sec, radius, label="radial displacement of COP")
    ax.plot(sec,
            radius_avg,
            label="radial displacement of COP averaged over"
            f" moving window of {BB_MOVING_WINDOW} seconds")
    ax.set_xlabel("seconds")
    ax.legend(loc="upper right")

    fig3, ax = plt.subplots()
    ax.set_xlabel("seconds")
    ax.set_ylabel("mm")
    ax.set_title("Sway path length")
    ax.plot(sec, total_path)

    # Insert the figure tag before any extension so savefig can still infer
    # the format ("log.png" -> "log_fig0.png"; "log" -> "log_fig0").
    for idx, fig in enumerate((fig0, fig1, fig2, fig3)):
        fig.savefig(
            logpath.with_name(f"{logpath.stem}_fig{idx}{logpath.suffix}"),
            dpi=200)
        plt.close(fig)
Пример #7
0
def preprocess_bb(subject, inputs, outputs, recompute, logpath):
    """Preprocess raw balance board channels from BrainVision files.

    1. Downsample from BB_SFREQ_ORIGINAL (2500Hz) to BB_SFREQ_DECIMATED (32Hz).
    2. Transform to millimeter unit (board displacement): subtract each
       channel's empty-board value, divide by the participant's weight
       (median summed signal minus summed empty-board values) and multiply by
       BB_BOARDLENGTH / 2.

    The result is written to ``outputs["save_path"]`` with one column per
    channel ("BB1"-"BB4"). Nothing happens when the output already exists and
    ``recompute`` is falsy. If ``logpath`` is truthy, a diagnostic figure is
    saved there.

    Raises:
        AssertionError: if the sampling frequency is unexpected, if no
            sufficiently long empty-board segment is found for a channel, or
            if the estimated participant weight does not exceed the summed
            empty-board value.
    """
    save_path = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if save_path.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {save_path}")
        return
    physio_path = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    raw = mne.io.read_raw_brainvision(physio_path,
                                      preload=False,
                                      verbose="error")
    bb = raw.get_data(picks=BB_CHANNELS)  # channels as rows, samples as columns
    sfreq = raw.info["sfreq"]

    assert sfreq == BB_SFREQ_ORIGINAL, (f"Sampling frequency {sfreq} doesn't"
                                        " match expected sampling frequency"
                                        f" {BB_SFREQ_ORIGINAL}.")

    # Decimate the four balance board channels from original sampling rate to
    # BB_SFREQ_DECIMATED HZ. Note that MNE's raw.apply_function() cannot be used since it
    # requires the preservation of the original sampling frequency.
    decimation_factor = int(np.floor(sfreq / BB_SFREQ_DECIMATED))
    bb_decimated = decimate_signal(bb, decimation_factor)

    # Assuming that the participant has been off the board at some time,
    # calculate the empty board value: For each channel, take the mean of the
    # longest consecutive chunk of data of at least BB_MIN_EMPTY seconds that
    # is below the minimum value + std.
    bb_mins = bb_decimated.min(axis=1) + bb_decimated.std(axis=1)
    bb_minsconsecutive = np.zeros((bb_mins.size, 2)).astype(int)
    bb_empty = np.zeros(bb_mins.size)
    min_duration = int(np.ceil(BB_SFREQ_DECIMATED * BB_MIN_EMPTY))

    for i in range(bb_mins.size):

        begs, ends, n = consecutive_samples(bb_decimated[i, :],
                                            lambda x: x < bb_mins[i],
                                            min_duration)
        assert begs.size > 0, (f"Did not find {BB_MIN_EMPTY} consecutive"
                               " seconds of empty board values.")
        # Find longest chunk and save its beginning and end.
        longest = n.argmax()
        beg = begs[longest]
        end = ends[longest]
        bb_empty[i] = bb_decimated[i, beg:end].mean()
        bb_minsconsecutive[i, 0] = beg
        bb_minsconsecutive[i, 1] = end

    # Calculate weight of the participant: sum the channels at every sample,
    # take the median over time and subtract the summed empty-board values.
    bb_chansum = np.sum(bb_decimated, axis=0)
    bb_chansum_empty = bb_empty.sum()
    bb_subjweight = np.median(bb_chansum) - bb_chansum_empty

    assert bb_subjweight > bb_chansum_empty, (f"Subject {bb_subjweight} is not"
                                              " heavier than empty board"
                                              f" ({bb_chansum_empty}).")

    # Transform sensor data to millimeter unit.
    bb_mm = np.subtract(bb_decimated, bb_empty.reshape(-1, 1))
    bb_mm = bb_mm / bb_subjweight  # scale by subject weight
    bb_mm = bb_mm * (BB_BOARDLENGTH / 2)  # express in mm

    # Transpose to change from channels as rows to channels as columns
    # (preserves ordering of channels).
    pd.DataFrame(bb_mm).T.to_csv(
        save_path,
        sep="\t",
        header=["BB1", "BB2", "BB3", "BB4"],
        index=False,
        float_format="%.4f")

    if not logpath:
        return

    fig, (ax0, ax1, ax2) = plt.subplots(nrows=3, ncols=1, sharex=True)
    sec = np.linspace(0, bb_decimated.shape[1] / BB_SFREQ_DECIMATED,
                      bb_decimated.shape[1])
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    channames = ["PR", "PL", "AL", "AR"]

    for chan in range(bb_decimated.shape[0]):
        color = colors[chan]
        channame = channames[chan]
        beg, end = bb_minsconsecutive[chan]
        ax0.plot(sec, bb_decimated[chan, :], c=color, label=f"{channame}")
        # axvspan's ymin/ymax are axes fractions (0-1), not data values, so the
        # defaults are used to shade the full axis height. The end index is
        # clipped in case the empty segment runs up to the last sample.
        ax0.axvspan(xmin=sec[beg],
                    xmax=sec[min(end, sec.size - 1)],
                    color=color,
                    alpha=.2,
                    label="empty")

    ax0.hlines(y=bb_empty,
               xmin=0,
               xmax=sec[-1],
               colors=colors[:bb_mins.size],
               linestyles="dotted",
               label="empty")
    ax0.legend(loc="upper right")

    ax1.plot(sec, bb_chansum)
    ax1.axhline(y=bb_subjweight,
                c="r",
                label="subject weight minus board weight")
    ax1.legend(loc="upper right")
    ax1.set_xlabel("seconds")

    for chan in range(bb_mm.shape[0]):
        color = colors[chan]
        channame = channames[chan]
        ax2.plot(sec, bb_mm[chan, :], c=color, label=f"{channame}")
    ax2.legend(loc="upper right")
    ax2.set_xlabel("seconds")
    ax2.set_ylabel("millimeters")

    fig.savefig(logpath, dpi=200)
    plt.close(fig)
Пример #8
0
def get_cop_bb(subject, inputs, outputs, recompute, logpath):
    """Compute center of pressure (COP) time series.

    1. Combine the four preprocessed balance board channels (columns 0-3 of
       ``inputs["physio_path"]``) into time series of anterior-posterior
       (back-forth) displacement, ``(c2 + c3) - (c0 + c1)``, and medio-lateral
       (left-right) displacement, ``(c0 + c3) - (c1 + c2)``.
    2. Band-pass filter both series between BB_FILTER_CUTOFFS[0] and
       BB_FILTER_CUTOFFS[1] (presumably Hz) at BB_SFREQ_DECIMATED.

    The filtered series are written to ``outputs["save_path"]`` as columns
    "ap_filt" and "ml_filt". Nothing happens when the output already exists
    and ``recompute`` is falsy. If ``logpath`` is truthy, two diagnostic
    figures are saved next to it.
    """
    save_path = Path(
        individualize_path(outputs["save_path"], subject, expand_name=True))
    if save_path.exists() and not recompute:  # only recompute if requested
        print(f"Not re-computing {save_path}")
        return
    physio_path = next(
        Path(".").glob(individualize_path(inputs["physio_path"], subject)))

    bb = pd.read_csv(physio_path, sep="\t", header=0).to_numpy()

    ap = (bb[:, 2] + bb[:, 3]) - (bb[:, 0] + bb[:, 1])  # anterior-posterior
    ml = (bb[:, 0] + bb[:, 3]) - (bb[:, 1] + bb[:, 2])  # medio-lateral

    low, high = BB_FILTER_CUTOFFS[0], BB_FILTER_CUTOFFS[1]
    ap_filt = butter_bandpass_filter(ap, low, high, BB_SFREQ_DECIMATED)
    ml_filt = butter_bandpass_filter(ml, low, high, BB_SFREQ_DECIMATED)

    pd.DataFrame({
        "ap_filt": ap_filt,
        "ml_filt": ml_filt
    }).to_csv(save_path,
              sep="\t",
              header=True,
              index=False,
              float_format="%.4f")

    if not logpath:
        return

    sec = np.linspace(0, bb.shape[0] / BB_SFREQ_DECIMATED, bb.shape[0])
    fig0, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, sharex=True)
    ax0.plot(sec, ap, label="anterior-posterior displacement")
    ax0.plot(sec, ap_filt, label="filtered anterior-posterior displacement")
    ax0.set_ylabel("millimeter")
    ax0.legend(loc="upper right")
    ax1.plot(sec, ml, label="medio-lateral displacement")
    ax1.plot(sec, ml_filt, label="filtered medio-lateral displacement")
    ax1.set_ylabel("millimeter")
    ax1.set_xlabel("seconds")
    ax1.legend(loc="upper right")

    fig1, ax = plt.subplots()
    ax.plot(ap_filt, ml_filt)
    ax.set_xlabel("anterior-posterior displacement (mm)")
    ax.set_ylabel("medio-lateral displacement (mm)")

    # Insert the figure tag before any extension so savefig can still infer
    # the format ("log.png" -> "log_fig0.png"; "log" -> "log_fig0").
    for idx, fig in enumerate((fig0, fig1)):
        fig.savefig(
            logpath.with_name(f"{logpath.stem}_fig{idx}{logpath.suffix}"),
            dpi=200)
        plt.close(fig)