Ejemplo n.º 1
0
def save_to_hdf5(power: AverageTFR):
    """
    Save ``power`` to ``<comment>_power-tfr.h5``.

    The output location is taken from ``power.comment``. Every decimal number
    in that path is truncated to an integer (``"3.0"`` -> ``"3"``) before the
    ``_power-tfr.h5`` suffix is appended. ``power.comment`` is overwritten
    with the rewritten path as a string.

    Parameters
    ----------
    power
        Object whose ``comment`` holds the target path without extension
        (a path object or a string) and which provides
        ``save(fname=..., overwrite=...)``.
    """
    # Replace floats in the path with integers. A single regex pass rewrites
    # each match exactly once; chained str.replace calls would also rewrite
    # digits belonging to other numbers (".5_0.5" would become "0_00").
    base_path = re.sub(r"[-+]?\d*\.\d+",
                       lambda match: str(int(float(match.group()))),
                       str(power.comment))
    power.comment = base_path
    file_path_with_extension = f"{base_path}_power-tfr.h5"

    # Create the directory of the path that is actually written: the float
    # replacement above may have changed directory names as well.
    parent_dir = os.path.dirname(file_path_with_extension)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    logger.info(f"Saving power at {file_path_with_extension} ...")
    power.save(fname=file_path_with_extension, overwrite=True)
    logger.info("[FINISHED]")
Ejemplo n.º 2
0
def average_power_into_frequency_bands(power: AverageTFR,
                                       config: dict) -> AverageTFR:
    """
    Average time-frequency power within the frequency bands given in config.

    Parameters
    ----------
    power
        Object exposing ``freqs``, ``data`` (indexed as
        ``[:, frequency, :]``), ``comment`` and ``copy()``.
    config
        Mapping whose values are ``(low, high)`` band edges; the edges may be
        numbers or numeric strings. A band covers ``low <= f < high``. When
        the integer part of the highest frequency in ``power`` equals the
        integer part of the highest band edge, that frequency is additionally
        included in the band(s) ending at that edge, so it is not lost to the
        exclusive upper bound.

    Returns
    -------
    AverageTFR
        Copy of ``power`` whose ``data`` has one entry per band on the
        frequency axis, whose ``freqs`` holds each band's lower edge, and
        whose ``comment`` has ``"_band_average"`` appended. ``power`` itself
        is not modified.

    Raises
    ------
    ValueError
        If ``config`` defines no bands.
    """
    bands = [(float(band[0]), float(band[1])) for band in config.values()]
    if not bands:
        raise ValueError("config must define at least one frequency band")

    freqs = np.asarray(power.freqs)
    top_edge = max(high for _, high in bands)
    top_freq_ix = int(np.argmax(freqs))
    # Loop-invariant: does the highest sampled frequency sit on the top edge?
    include_top_freq = int(freqs[top_freq_ix]) == int(top_edge)

    band_means = []
    for low, high in bands:
        in_band = np.logical_and(freqs >= low, freqs < high)
        # Only the band that ends at the top edge may absorb the highest
        # frequency; adding it to every band would contaminate the lower ones.
        if include_top_freq and high == top_edge:
            in_band[top_freq_ix] = True
        band_means.append(power.data[:, in_band, :].mean(axis=1))

    band_power = power.copy()
    # Stack the per-band means along a new frequency axis (axis 1).
    band_power.data = np.stack(band_means, axis=1)
    band_power.freqs = np.asarray([low for low, _ in bands])
    band_power.comment = power.comment + "_band_average"

    return band_power
Ejemplo n.º 3
0
    # Roll covariance, csp and lda over time
    # NOTE(review): this loop is indented, so it runs inside an enclosing
    # block that is not part of this excerpt; `freq`, `epochs`, `clf`, `cv`,
    # `y`, `w_size` and `scores` are presumably set up there - confirm.
    for t, w_time in enumerate(centered_w_times):

        # Center the min and max of the window
        w_tmin = w_time - w_size / 2.
        w_tmax = w_time + w_size / 2.

        # Crop data into time-window of interest
        # (the crop is applied to a copy, so `epochs` is left untouched for
        # the next window)
        X = epochs.copy().crop(w_tmin, w_tmax).get_data()

        # Save mean scores over folds for each frequency and time window
        # (no `scoring` is given, so cross_val_score falls back to the
        # estimator's own default scorer)
        scores[freq, t] = np.mean(cross_val_score(estimator=clf,
                                                  X=X,
                                                  y=y,
                                                  cv=cv,
                                                  n_jobs=1),
                                  axis=0)

###############################################################################
# Plot results

# Set up time frequency object
# The score matrix gets a leading axis of length 1 and is wrapped as the data
# of a single pseudo-channel named 'freq', with the window centers as the time
# axis. `freqs[1:]` drops the first entry of `freqs`.
# NOTE(review): presumably `freqs` holds band edges, one more than the number
# of rows in `scores` - confirm against the code that builds it.
av_tfr = AverageTFR(create_info(['freq'], sfreq), scores[np.newaxis, :],
                    centered_w_times, freqs[1:], 1)

chance = np.mean(y)  # set chance level to white in the plot
# `chance` (the mean of the labels in `y`) is used as the lower color limit.
av_tfr.plot([0],
            vmin=chance,
            title="Time-Frequency Decoding Scores",
            cmap=plt.cm.Reds)
Ejemplo n.º 4
0
# Define wavelet frequencies and number of cycles
# (the number of cycles grows linearly with frequency: n_cycles = freq / 7)
cwt_freqs = np.arange(7, 30, 2)
cwt_n_cycles = cwt_freqs / 7.

# Run the connectivity analysis in a single job (raise n_jobs to parallelize)
sfreq = raw.info['sfreq']  # the sampling frequency
con, freqs, times, _, _ = spectral_connectivity(epochs,
                                                indices=indices,
                                                method='wpli2_debiased',
                                                mode='cwt_morlet',
                                                sfreq=sfreq,
                                                cwt_freqs=cwt_freqs,
                                                cwt_n_cycles=cwt_n_cycles,
                                                n_jobs=1)

# Mark the seed channel with a value of 1.0, so we can see it in the plot
con[np.where(indices[1] == seed)] = 1.0

# Show topography of connectivity from seed
title = 'WPLI2 - Visual - Seed %s' % seed_ch

layout = mne.find_layout(epochs.info, 'meg')  # use full layout

tfr = AverageTFR(epochs.info, con, times, freqs, len(epochs))
# Pass the title and layout prepared above; they were previously computed but
# never handed to the plot.
tfr.plot_topo(fig_facecolor='w', font_color='k', border='k',
              layout=layout, title=title)

# %%
# References
# ----------
# .. footbibliography::
# Use 'MEG 2343' as seed
seed_ch = 'MEG 2343'
picks_ch_names = [raw.ch_names[i] for i in picks]

# Create seed-target indices for connectivity computation
# (connectivity is computed from the seed to every picked channel, the seed
# itself included)
seed = picks_ch_names.index(seed_ch)
targets = np.arange(len(picks))
indices = seed_target_indices(seed, targets)

# Define wavelet frequencies and number of cycles
# (the number of cycles grows linearly with frequency: n_cycles = freq / 7)
cwt_freqs = np.arange(7, 30, 2)
cwt_n_cycles = cwt_freqs / 7.

# Run the connectivity analysis in a single job (raise n_jobs to parallelize)
sfreq = raw.info['sfreq']  # the sampling frequency
con, freqs, times, _, _ = spectral_connectivity(
    epochs, indices=indices,
    method='wpli2_debiased', mode='cwt_morlet', sfreq=sfreq,
    cwt_freqs=cwt_freqs, cwt_n_cycles=cwt_n_cycles, n_jobs=1)

# Mark the seed channel with a value of 1.0, so we can see it in the plot
con[np.where(indices[1] == seed)] = 1.0

# Show topography of connectivity from seed
title = 'WPLI2 - Visual - Seed %s' % seed_ch

layout = mne.find_layout(epochs.info, 'meg')  # use full layout

tfr = AverageTFR(epochs.info, con, times, freqs, len(epochs))
# Pass the title and layout prepared above; they were previously computed but
# never handed to the plot.
tfr.plot_topo(fig_facecolor='w', font_color='k', border='k',
              layout=layout, title=title)
Ejemplo n.º 6
0
    # NOTE(review): this fragment is indented, so it runs inside an enclosing
    # block that is not part of this excerpt; `raw_filter`, `freq`, `clf`,
    # `cv`, `le` and `tf_scores` are presumably set up there - confirm.
    # Epoch the filtered data with the time range extended by `w_size` on
    # both sides, no baseline correction, and data preloaded.
    epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                    proj=False, baseline=None, preload=True)
    epochs.drop_bad()
    # Encode the event codes (third column of the events array) as labels.
    y = le.fit_transform(epochs.events[:, 2])

    # Roll covariance, csp and lda over time
    for t, w_time in enumerate(centered_w_times):

        # Center the min and max of the window
        w_tmin = w_time - w_size / 2.
        w_tmax = w_time + w_size / 2.

        # Crop data into time-window of interest
        # (the crop is applied to a copy, so `epochs` is left untouched for
        # the next window)
        X = epochs.copy().crop(w_tmin, w_tmax).get_data()

        # Save mean scores over folds for each frequency and time window
        # (scored with ROC AUC)
        tf_scores[freq, t] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                     scoring='roc_auc', cv=cv,
                                                     n_jobs=1), axis=0)

###############################################################################
# Plot time-frequency results

# Set up time frequency object
# The score matrix gets a leading axis of length 1 and is wrapped as the data
# of a single pseudo-channel named 'freq', with the window centers as the time
# axis. `freqs[1:]` drops the first entry of `freqs`.
# NOTE(review): presumably `freqs` holds band edges, one more than the number
# of rows in `tf_scores` - confirm against the code that builds it.
av_tfr = AverageTFR(create_info(['freq'], sfreq), tf_scores[np.newaxis, :],
                    centered_w_times, freqs[1:], 1)

chance = np.mean(y)  # set chance level to white in the plot
# `chance` (the mean of the labels in `y`) is used as the lower color limit.
av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores",
            cmap=plt.cm.Reds)
Ejemplo n.º 7
0
# Connectivity
# Tag each condition's data with subject and condition, then compute wPLI
# over time for it.
for ix_c, c in enumerate([exp_sup_lon, exp_sup_sho]):
    c.info['subject_info'] = 'P14'
    c.info['cond'] = conds[ix_c]
    calc_wpli_over_time(c)

# Load one result archive per condition.
dats = list()
for c in conds:
    # The archive stores the measurement info as a pickled object, which
    # NumPy refuses to load unless allow_pickle=True.
    # NOTE: unpickling can execute arbitrary code - only load result files
    # produced by this pipeline.
    dat = np.load(
        op.join(study_path, 'results', 'wpli', 'P14_{}_wpli.npz'.format(c)),
        allow_pickle=True)
    dats.append(dat)

# The measurement info is taken from the last loaded archive.
info = dat['info'].item()

for ix_c, c in enumerate(conds):
    # Read data, times and frequencies from the archive of the condition
    # being plotted rather than from whichever archive was loaded last.
    dat = dats[ix_c]
    tfr = AverageTFR(info, dat['con'][1, :, :, :], dat['times'],
                     dat['freqs'], 1)
    tfr.plot(picks=[83], vmin=-1, vmax=1, cmap='viridis', title=c)

# TF ROI
# The split of `picks` into the two regions depends on the reference scheme.
# NOTE(review): for any other value of `ref`, `rois` is not assigned here -
# confirm it is defined elsewhere or that `ref` is always 'avg' or 'bip'.
if ref == 'avg':
    rois = {
        'HP_r': np.array(raw.ch_names)[picks[:4]],
        'HP_l': np.array(raw.ch_names)[picks[4:]]
    }
elif ref == 'bip':
    rois = {
        'HP_l': np.array(raw.ch_names)[picks[:2]],
        'HP_r': np.array(raw.ch_names)[picks[2:]]
    }

# One row of axes per ROI, one column per condition.
roi_fig, axes = plt.subplots(len(rois), len(conds), sharey=True, sharex=True)
Ejemplo n.º 8
0
def load_patient_tfr(deriv_path, subject, band, task=None, verbose=True):
    """Load band-averaged results for every matching ``.h5`` file of a subject.

    Parameters
    ----------
    deriv_path
        Directory searched recursively for result files.
    subject
        Subject identifier; files must match ``*sub-<subject>_*.h5``.
    band
        One of ``"delta"``, ``"theta"``, ``"alpha"``, ``"beta"``,
        ``"gamma"`` or ``"highgamma"``.
    task
        If given, files must additionally match ``task-<task>``.
    verbose
        If true, print a message for every file that is skipped.

    Returns
    -------
    list
        One ``Result`` per file that has no NaNs in the selected band and for
        which a seizure-onset window was found in the JSON sidecar.

    Raises
    ------
    ValueError
        If ``band`` is not one of the supported names.
    """
    from mne.time_frequency import read_tfrs

    # Band limits; a band covers fmin <= f < fmax.
    band_limits = {
        "delta": (0.5, 5),
        "theta": (5, 10),
        "alpha": (10, 16),
        "beta": (16, 30),
        "gamma": (30, 90),
        "highgamma": (90, 300),
    }
    # Validate the band before touching the filesystem.
    try:
        fmin, fmax = band_limits[band]
    except (KeyError, TypeError):
        raise ValueError(
            "kwarg 'band' can only be prespecified set of values"
        ) from None

    if task is not None:
        search_str = f"*sub-{subject}_*task-{task}*.h5"
    else:
        search_str = f"*sub-{subject}_*.h5"

    deriv_files = [f.as_posix() for f in Path(deriv_path).rglob(search_str)]

    patient_tfrs = []
    for deriv_fpath in deriv_files:
        avg_tfr = read_tfrs(deriv_fpath)[0]

        # Average over the frequencies inside the band, keeping the
        # frequency axis with length 1.
        in_band = (avg_tfr.freqs >= fmin) & (avg_tfr.freqs < fmax)
        band_data = np.mean(avg_tfr.data[:, in_band, :], axis=1, keepdims=True)

        # Skip unusable files before doing any further work on them.
        if np.isnan(band_data).any():
            if verbose:
                print(f"Skipping {deriv_fpath} due to nans")
            continue

        # The sidecar sits next to the data file. Swap only the suffix:
        # str.replace would also rewrite ".h5" inside directory names.
        json_fpath = Path(deriv_fpath).with_suffix(".json")
        with open(json_fpath, "r", encoding="utf-8") as fin:
            sidecar_json = json.load(fin)

        # obtain the event IDs for the markers of interest
        sz_onset_id = sidecar_json.get("sz_onset_event_id")
        sz_offset_id = sidecar_json.get("sz_offset_event_id")
        clin_onset_id = sidecar_json.get("clin_onset_event_id")

        # events array
        events = sidecar_json["events"]

        # obtain onset/offset event markers
        sz_onset_win = _get_onset_event_id(events, sz_onset_id)
        sz_offset_win = _get_onset_event_id(events, sz_offset_id)
        clin_onset_win = _get_onset_event_id(events, clin_onset_id)

        if sz_onset_win is None:
            if verbose:
                print(f"Skipping {deriv_fpath}")
            continue

        # attach the windows and the band limits to the metadata
        sidecar_json["sz_onset_win"] = sz_onset_win
        sidecar_json["sz_offset_win"] = sz_offset_win
        sidecar_json["clin_onset_win"] = clin_onset_win
        sidecar_json["freq_band"] = (fmin, fmax)

        # create a Result object from the band data with the length-1
        # frequency axis squeezed out
        band_tfr = Result(
            Normalize.compute_fragilitymetric(band_data.squeeze(), invert=True),
            info=avg_tfr.info,
            metadata=sidecar_json,
        )
        patient_tfrs.append(band_tfr)
    return patient_tfrs
Ejemplo n.º 9
0
def wpli_analysis_time(subjects, log):
    """Plot seed-based wPLI over time per condition and run a cluster test.

    For every subject and condition ('lon', 'sho') the connectivity archive
    ``<subject>_<cond>_wpli.npz`` is loaded from
    ``study_path/results/wpli/over_time``. The connectivity of one fixed seed
    channel is collected, averaged over subjects and plotted as a topography.
    A cluster permutation test then compares the two conditions at one fixed
    channel, and the statistic is drawn as a time-frequency image.

    Parameters
    ----------
    subjects
        Sequence of subject identifiers used to build the file names.
    log
        Unused; kept so existing callers keep working.

    Raises
    ------
    ValueError
        If ``subjects`` is empty.
    """
    if len(subjects) == 0:
        raise ValueError("subjects must contain at least one subject")

    conds = ['lon', 'sho']
    seed_ix = 33             # channel index used as connectivity seed
    test_ch_ix = 19          # channel index entering the cluster test
    example_subject_ix = 12  # subject index shown in the single-subject plots
    roi_cons = {}

    for ix_s, s in enumerate(subjects):
        print('subject {} of {}' .format(ix_s+1, len(subjects)))
        for c in conds:
            filename = op.join(study_path, 'results', 'wpli', 'over_time',
                               '{}_{}_wpli.npz' .format(s, c))
            # The archive stores the measurement info as a pickled object,
            # which NumPy refuses to load unless allow_pickle=True.
            # NOTE: unpickling can execute arbitrary code - only load result
            # files produced by this pipeline.
            dat = np.load(filename, allow_pickle=True)
            info = dat['info'].item()
            # Read the array once; every dat[...] access re-reads the archive.
            con = dat['con']

            # Create results matrices
            if c not in roi_cons:
                roi_cons[c] = np.empty((len(subjects), con.shape[0],
                                        con.shape[-2], con.shape[-1]))

            # Get ROI connectivity: all connections of the seed channel, with
            # the seed's own entry set to 1.0.
            roi_con = con[seed_ix, :, :, :]
            roi_con[seed_ix, :, :] = 1.0
            roi_cons[c][ix_s, :, :, :] = roi_con

    # Times and frequencies are taken from the last archive loaded.
    times_s = dat['times']
    freqs = dat['freqs']

    avg_con = [np.mean(roi_cons[c], axis=0) for c in conds]

    for c, cond_avg in zip(conds, avg_con):
        tfr = AverageTFR(info, cond_avg, times_s, freqs, len(subjects))
        tfr.plot_topo(fig_facecolor='w', font_color='k', border='k', vmin=0,
                      vmax=0.5, cmap='viridis', title=c)

    # Single-subject example; skipped when the cohort is too small to contain
    # that subject instead of failing with an IndexError.
    if example_subject_ix < len(subjects):
        for c in conds:
            # nave=1: this is one subject's data, not an average.
            tfr = AverageTFR(info, roi_cons[c][example_subject_ix, :, :, :],
                             times_s, freqs, 1)
            tfr.plot_topo(fig_facecolor='w', font_color='k', border='k',
                          vmin=0, vmax=1, cmap='viridis', title=c)

    # Stats
    test_con = [roi_cons[c][:, test_ch_ix, :, :] for c in conds]

    # The threshold is given as a dict with start and step rather than as a
    # single value.
    threshold = dict(start=0, step=0.2)
    T_obs, clusters, cluster_p_values, H0 = \
        permutation_cluster_test([test_con[0], test_con[1]],
                                 n_permutations=1000, threshold=threshold,
                                 tail=0)

    # Convert to milliseconds without modifying the loaded array in place.
    times = times_s * 1e3

    fig, ax = plt.subplots(1)
    # Keep the statistic only inside clusters with p <= 0.05; NaN elsewhere.
    T_obs_plot = np.nan * np.ones_like(T_obs)
    for cluster, p_val in zip(clusters, cluster_p_values):
        if p_val <= 0.05:
            T_obs_plot[cluster] = T_obs[cluster]

    extent = [times[0], times[-1], freqs[0], freqs[-1]]
    # Full statistic in gray, overlaid in color by the p <= 0.05 clusters.
    ax.imshow(T_obs, extent=extent,
              aspect='auto', origin='lower', cmap='gray')
    ax.imshow(T_obs_plot, extent=extent,
              aspect='auto', origin='lower', cmap='RdBu_r')

    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_title('ROI Connectivity')