Example #1
def update_metadata_epochs_and_save_epochs(subject):
    """
    This function updates the metadata fields for the epochs such that they contain all the useful information for
    the complexity and surprise regressions.
    """

    # update the metadata for the non-clean epochs by adding the surprise computed for an observer that has 100 items in memory.
    metadata_notclean = TP_funcs.from_epochs_to_surprise(subject, [100])
    epochs_notclean, fname = epoching_funcs.load_epochs_items(
        subject, cleaned=False, return_fname=True)

    # load the cleaned epochs; removing the bad epochs from the non-clean metadata
    # will give the metadata for the cleaned epochs
    epochs_clean, fname_clean = epoching_funcs.load_epochs_items(
        subject, cleaned=True, return_fname=True)
    epochs_clean.get_data()  # make sure the data are loaded
    # ============ build the repeatAlter and the surprise 100 for n+1 ==================
    # 1 - update the full epochs (not_clean) metadata with the new fields
    # shift by one item: the value assigned to epoch n is the value of epoch n+1
    RepeatAlternp1_notclean = metadata_notclean["RepeatAlter"].values[1:].tolist()
    RepeatAlternp1_notclean.append(np.nan)
    Surprisenp1_notclean = metadata_notclean["surprise_100"].values[1:].tolist()
    Surprisenp1_notclean.append(np.nan)
    metadata_notclean = metadata_notclean.assign(Intercept=1)
    metadata_notclean = metadata_notclean.assign(
        RepeatAlternp1=RepeatAlternp1_notclean)
    metadata_notclean = metadata_notclean.assign(
        Surprisenp1=Surprisenp1_notclean)
    epochs_notclean.metadata = metadata_notclean
    epochs_notclean.save(fname, overwrite=True)

    # 2 - subselect only the good epochs indices to filter the metadata
    if subject == 'sub16-ma_190185':
        # in the case of sub16, no epochs are removed in the process of cleaning
        metadata_clean = metadata_notclean
    else:
        good_idx = [len(log) == 0 for log in epochs_clean.drop_log]
        where_good = np.where(good_idx)[0]
        RepeatAlternp1 = np.asarray(RepeatAlternp1_notclean)[where_good]
        Surprisenp1 = np.asarray(Surprisenp1_notclean)[where_good]
        metadata_clean = metadata_notclean[good_idx]

        metadata_clean = metadata_clean.assign(
            Intercept=1)  # Add an intercept for later
        metadata_clean = metadata_clean.assign(RepeatAlternp1=RepeatAlternp1)
        metadata_clean = metadata_clean.assign(
            Surprisenp1=Surprisenp1)  # Add the surprise of item n+1

    epochs_clean.metadata = metadata_clean
    epochs_clean.save(fname_clean, overwrite=True)

    return True
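The "n+1" fields above are built by shifting each column one step so that row n holds the value of row n+1, padding the last entry with NaN. A minimal, self-contained sketch of that shift using pandas (toy values, hypothetical column contents):

import pandas as pd

meta = pd.DataFrame({"RepeatAlter": [0, 1, 1, 0],
                     "surprise_100": [0.2, 0.9, 0.4, 0.7]})
meta = meta.assign(
    Intercept=1,
    RepeatAlternp1=meta["RepeatAlter"].shift(-1),   # value of the next item, NaN for the last
    Surprisenp1=meta["surprise_100"].shift(-1),
)
print(meta)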
Example #2
def append_surprise_to_metadata_clean(subject):
    """
    Load the metadata that contains the surprise for the non-clean epochs, removes the bad epochs from the metadata
    and this becomes the metadata for the clean epochs
    :param subject:
    :return:
    """

    meg_subject_dir = op.join(config.meg_dir, subject)
    if config.noEEG:
        meg_subject_dir = op.join(meg_subject_dir, 'noEEG')

    metadata_path = op.join(meg_subject_dir, 'metadata_item_clean.pkl')

    metadata = epoching_funcs.update_metadata(subject,
                                              clean=False,
                                              new_field_name=None,
                                              new_field_values=None,
                                              recompute=False)
    epochs = epoching_funcs.load_epochs_items(subject, cleaned=True)
    good_idx = [len(log) == 0 for log in epochs.drop_log]
    metadata_final = metadata[good_idx]

    with open(metadata_path, 'wb') as fid:
        pickle.dump(metadata_final, fid)

    return True
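For reference, epochs.drop_log has one entry per original epoch and an empty entry means the epoch was kept, which is what the good_idx selection relies on. A toy illustration (hypothetical log):

# drop_log = ((), ('AUTOREJECT',), (), ('IGNORED',))
# good_idx = [len(log) == 0 for log in drop_log]  # -> [True, False, True, False]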
Example #3
def run_linear_regression_surprises(subject,
                                    omega_list,
                                    clean=False,
                                    decim=None,
                                    prefix='',
                                    Ridge=False,
                                    hfilter=20):

    epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
    epochs.pick_types(meg=True, eeg=True)
    if hfilter is not None:
        epochs.filter(None, hfilter)

    if decim is not None:
        epochs.decimate(decim)

    metadata = epoching_funcs.update_metadata(subject,
                                              clean=clean,
                                              new_field_name=None,
                                              new_field_values=None)
    epochs.metadata = metadata
    df = epochs.metadata
    epochs.metadata = df.assign(Intercept=1)
    r2_surprise = {omega: [] for omega in omega_list}
    r2_surprise['times'] = epochs.times
    epochs_for_reg = epochs[np.where(
        1 - np.isnan(epochs.metadata["surprise_1"].values))[0]]
    epochs_for_reg = epochs_for_reg["SequenceID != 1"]
    epochs_for_reg_normalized = normalize_data(epochs_for_reg)

    out_path = op.join(config.result_path, 'TP_effects', 'surprise_omegas',
                       subject)
    utils.create_folder(out_path)
    if not Ridge:
        for omega in omega_list:
            print("==== running the regression for omega %i =======" % omega)
            surprise_name = "surprise_%.005f" % omega
            r2_surprise[omega] = linear_regression_from_sklearn(
                epochs_for_reg_normalized, surprise_name)
        # ===== save all the regression results =========
        fname = prefix + 'results_surprise.npy'
        np.save(op.join(out_path, fname), r2_surprise)

    else:
        surprise_names = ["surprise_%i" % omega for omega in omega_list]
        results_ridge = multi_ridge_regression_allIO(epochs_for_reg_normalized,
                                                     surprise_names)
        fname = prefix + 'results_Ridge_surprise.npy'
        np.save(op.join(out_path, fname), results_ridge)

    return True
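linear_regression_from_sklearn is defined elsewhere in the codebase; the following is only a minimal sketch consistent with how it is called here (an assumption about its behavior, not the actual implementation):

import numpy as np
from sklearn.linear_model import LinearRegression

def linear_regression_from_sklearn(epochs_for_reg, surprise_name):
    # regress the epoch data on the surprise regressor, all channels and
    # time points fitted jointly as multi-output targets (sketch)
    x = epochs_for_reg.metadata[surprise_name].values[:, np.newaxis]
    y = epochs_for_reg.get_data()               # (n_epochs, n_channels, n_times)
    reg = LinearRegression().fit(x, y.reshape(len(y), -1))
    return reg.score(x, y.reshape(len(y), -1))  # R2 of the fit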
Example #4
def extract_good_epochs_for_RSA(subject,
                                tmin,
                                tmax,
                                baseline,
                                decim,
                                clean,
                                recompute=True):
    """
    This function computes and saves the epochs epoched for the RSA.
    :param subject:
    :param tmin:
    :param tmax:
    :param baseline:
    :param decim:
    :param reject:
    :param which_analysis:
    :return:
    """

    if recompute:
        print("=== we are recomputing the epochs but not saving them ! ===")
        if clean:
            epochs = epoching_funcs.run_epochs(subject,
                                               epoch_on_first_element=False,
                                               tmin=tmin,
                                               tmax=tmax,
                                               baseline=baseline,
                                               whattoreturn='ARglobal')
        else:
            epochs = epoching_funcs.run_epochs(subject,
                                               epoch_on_first_element=False,
                                               tmin=tmin,
                                               tmax=tmax,
                                               baseline=baseline,
                                               whattoreturn='')
    else:
        epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)

    epochs.pick_types(meg=True)
    epochs.crop(tmin, tmax)
    if decim is not None:
        epochs.decimate(decim)
    if baseline is not None:
        epochs.apply_baseline(baseline)
    epochs = epochs[
        "TrialNumber > 10 and ViolationInSequence == 0 and StimPosition > 1"]

    return epochs
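The string-based selection above is resolved by MNE as a pandas query on epochs.metadata; an equivalent explicit form would be (sketch):

# mask = ((epochs.metadata["TrialNumber"] > 10)
#         & (epochs.metadata["ViolationInSequence"] == 0)
#         & (epochs.metadata["StimPosition"] > 1)).values
# epochs = epochs[mask]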
Example #5
def compute_explained_variance(subject, clean=True, fname='residual_blabla-epo.fif'):

    from sklearn.metrics import r2_score

    epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
    metadata = epoching_funcs.update_metadata(subject, clean=clean, new_field_name=None, new_field_values=None)
    epochs.metadata = metadata
    epochs.pick_types(meg=True, eeg=True)
    epochs = epochs[np.where(1 - np.isnan(epochs.metadata["surprise_1"].values))[0]]
    y_true = epochs.get_data()

    epo_res = mne.read_epochs(op.join(config.meg_dir, subject, fname))
    # res = y_true - y_pred => y_pred = y_true - res
    y_pred = y_true - epo_res.get_data()

    # r2_score expects 2D inputs, so collapse the (channel, time) dimensions
    R2 = r2_score(y_true.reshape(len(y_true), -1),
                  y_pred.reshape(len(y_pred), -1))

    return R2
Example #6
def localize_standard_VS_deviant_code(subject, n_permutations=2000, n_channels=30, select_grad=False, cleaned=True):

    # ----------- load the epochs ---------------
    epochs = epoching_funcs.load_epochs_items(subject, cleaned=cleaned)
    epochs.pick_types(meg=True)

    # ----------- balance the position of the standard and the deviants -------
    # 'local' - just make sure we have the same number of standards and deviants for a given position. This may end up
    #     with 1 standard/deviant for position 9 and 4 for the other positions.
    epochs_balanced = epoching_funcs.balance_epochs_violation_positions(epochs, balance_param="local")
    # ----------- do a sliding window to smooth the data -------
    epochs_balanced = epoching_funcs.sliding_window(epochs_balanced)

    # =============================================================================================
    toi = 0.165  # time of interest (in seconds) for the decoding
    epochs_for_decoding = epochs_balanced.copy().crop(tmin=toi, tmax=toi)
    training_inds, testing_inds = SVM_funcs.train_test_different_blocks(epochs_for_decoding, return_per_seq=False)
    y_violornot = np.asarray(epochs_for_decoding.metadata['ViolationOrNot'].values)
    labels_train = [y_violornot[training_inds[i]] for i in range(2)]
    labels_test = [y_violornot[testing_inds[i]] for i in range(2)]

    performance_loc = compute_sensor_weights_decoder(epochs_for_decoding,
                                                          SVM_funcs.SVM_decoder(),
                                                          training_inds,
                                                          labels_train,
                                                          testing_inds,
                                                          labels_test, None,
                                                          None, n_permutations,
                                                          n_channels,select_grad=select_grad)

    suffix = ''
    if select_grad:
        suffix = 'only_grad'

    save_path = op.join(config.result_path, 'localization', 'Standard_VS_Deviant')
    utils.create_folder(save_path)
    save_path_subject = op.join(save_path, subject, suffix)
    utils.create_folder(save_path_subject)

    # e.g. 'results2000_permut30_chans_165.npy' with the default parameters
    res_name = 'results%i_permut%i_chans_%i.npy' % (n_permutations, n_channels, round(toi * 1000))
    np.save(op.join(save_path_subject, res_name), performance_loc)
Example #7
def filter_good_epochs_for_regression_analysis(subject,
                                               clean=True,
                                               fields_of_interest=[
                                                   'surprise_100',
                                                   'RepeatAlternp1'
                                               ]):
    """
    This function removes the epochs that have NaNs in the fields of interest specified in the list.
    """
    epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
    if fields_of_interest is not None:
        for field in fields_of_interest:
            epochs = epochs[np.where(
                1 - np.isnan(epochs.metadata[field].values))[0]]
            print(
                "--- removing the epochs that have Nan values for field %s ----\n"
                % field)

    if config.noEEG:
        epochs = epochs.pick_types(meg=True, eeg=False)
    else:
        epochs = epochs.pick_types(meg=True, eeg=True)

    return epochs
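The np.where(1 - np.isnan(...)) construction keeps the integer indices of the epochs whose regressor is not NaN; an equivalent and arguably clearer boolean-mask form would be (sketch):

# epochs = epochs[~np.isnan(epochs.metadata[field].values)]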
Example #8
def plot_features_from_metadata(sequences=[3, 4, 5, 6, 7]):

    figures_path = config.fig_path + '/features_figs/'

    # load metadata subject 1
    epo = epoching_funcs.load_epochs_items(config.subjects_list[0],
                                           cleaned=False)
    metadata = epo.metadata

    metadata_all = []
    for seqID in sequences:
        print(seqID)
        meta_all_seq = []
        for posinSeq in range(1, 17):
            meta_1 = metadata.query(
                "SequenceID == '%i' and StimID == 1 and ViolationInSequence == 0 and StimPosition == '%i' and TrialNumber == 1 "
                % (seqID, posinSeq))
            meta_all_seq.append(meta_1)
        meta_all_seq = pd.concat(meta_all_seq)
        metadata_all.append(meta_all_seq)

    for feature_name in [
            'StimID', 'Complexity', 'GlobalEntropy', 'StimPosition',
            'RepeatAlter', 'ChunkNumber', 'WithinChunkPosition',
            'WithinChunkPositionReverse', 'ChunkDepth', 'OpenedChunks',
            'ClosedChunks', 'ChunkBeginning', 'ChunkEnd', 'ChunkSize'
    ]:
        # Plot
        # Prepare colors range
        cm = plt.get_cmap('viridis')
        metadata_allseq = pd.concat(metadata_all)
        metadata_allseq_reg = metadata_allseq[feature_name]
        minvalue = np.nanmin(metadata_allseq_reg)
        maxvalue = np.nanmax(metadata_allseq_reg)
        # Open figure
        if len(sequences) == 5:
            fig, ax = plt.subplots(5,
                                   1,
                                   figsize=(8.7, 4.4),
                                   sharex=False,
                                   sharey=True,
                                   constrained_layout=True)
        else:
            fig, ax = plt.subplots(len(sequences),
                                   1,
                                   figsize=(8.7, 6),
                                   sharex=False,
                                   sharey=True,
                                   constrained_layout=True)
        fig.suptitle(feature_name, fontsize=12)
        # Plot each sequence with circle colors corresponding to the regressor value
        for nseq, seqs in enumerate(sequences):

            seqname, seqtxtXY, violation_positions = epoching_funcs.get_seqInfo(
                seqs)
            ax[nseq].set_title(seqname, loc='left', weight='bold', fontsize=12)
            metadata = metadata_all[nseq][feature_name]
            # Normalize between 0 and 1 based on possible values across sequences, in order to set the color
            metadata = (metadata - minvalue) / (maxvalue - minvalue)
            # StimID is always 1 in the selected metadata, so we recover the x/Y identity from seqtxtXY instead
            if feature_name == 'StimID':
                for ii in range(len(seqtxtXY)):
                    if seqtxtXY[ii] == 'x':
                        metadata[metadata.index[ii]] = 0
                    elif seqtxtXY[ii] == 'Y':
                        metadata[metadata.index[ii]] = 1
            for stimpos in range(0, 16):
                value = metadata[metadata.index[stimpos]]
                if ~np.isnan(value):
                    circle = plt.Circle((stimpos + 1, 0.5),
                                        0.4,
                                        facecolor=cm(value),
                                        edgecolor='k',
                                        linewidth=1)
                else:
                    circle = plt.Circle((stimpos + 1, 0.5),
                                        0.4,
                                        facecolor='white',
                                        edgecolor='k',
                                        linewidth=1)
                ax[nseq].add_artist(circle)
            ax[nseq].set_xlim([0, 17])
            for key in ('top', 'right', 'bottom', 'left'):
                ax[nseq].spines[key].set(visible=False)
            ax[nseq].set_xticks([])
            ax[nseq].set_yticks([])
        # Add "xY" using the same yval for all
        ylim = ax[nseq].get_ylim()
        yval = ylim[1] - ylim[1] * 0.1
        for nseq, seqs in enumerate(sequences):
            seqname, seqtxtXY, violation_positions = epoching_funcs.get_seqInfo(
                seqs)
            print(seqname)
            for xx in range(16):
                ax[nseq].text(xx + 1,
                              0.5,
                              seqtxtXY[xx],
                              horizontalalignment='center',
                              verticalalignment='center',
                              fontsize=12)

        suffix = ''
        if len(sequences) == 5:
            suffix = '_withoutSeqID12'
        fig_name = op.join(figures_path,
                           feature_name + '_regressor' + suffix + '.png')
        print('Saving ' + fig_name)
        plt.savefig(fig_name, bbox_inches='tight', dpi=300)
        plt.close(fig)
Example #9
import os
import numpy as np
import scipy.io as sio
import pandas as pd
import config
import os.path as op
import mne
import glob
import warnings
from autoreject import AutoReject, get_rejection_threshold
import pickle

# subject = config.subjects_list[11]
subject = 'sub08-cc_150418'
meg_subject_dir = op.join(config.meg_dir, subject)
epochs = epoching_funcs.load_epochs_items(subject, cleaned=False)

# run autoreject "global" -> just get the thresholds
reject = get_rejection_threshold(epochs, ch_types=['mag', 'grad', 'eeg'])
epochs1 = epochs.copy().drop_bad(reject=reject)
fname = op.join(meg_subject_dir, 'epochs_globalAR-epo.fif')
print("Saving: ", fname)
epochs1.save(fname, overwrite=True)

# run autoreject "local"
ar = AutoReject()
epochs2, reject_log = ar.fit_transform(epochs, return_log=True)
fname = op.join(meg_subject_dir, 'epochs_localAR-epo.fif')
print("Saving: ", fname)
epochs2.save(fname, overwrite=True)
# Save autoreject reject_log
Example #10
# fig = sns.heatmap(np.mean(omega_argmax,axis=0))
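# NOTE: this fragment assumes that omega_optimal (a dict with at least a 'time' key)
# and omega_argmax (assumed to be a subjects x times x channels array of optimal
# omegas) were loaded beforehand, and that matplotlib.pyplot is imported as plt.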
times_of_interest = [-0.1,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7]
inds_t = np.hstack([np.where(omega_optimal['time']==tim)[0] for tim in times_of_interest])
plt.figure(figsize=(10, 20))
plt.imshow(np.mean(omega_argmax,axis=0), origin='lower')
plt.yticks(inds_t, times_of_interest)
plt.colorbar()
plt.title("Optimal Omega f(time,channel)")
plt.ylabel("Time in sec")
plt.xlabel("Channel index")



# ___________ plot like plot joint the optimal omega _______________________
data_to_plot = np.mean(omega_argmax,axis=0)
epoch = epoching_funcs.load_epochs_items(config.subjects_list[0])
average = epoch.average()
average._data = data_to_plot.T
average.plot_joint()

# __________ average across channels _________
plt.plot(omega_optimal['time'],np.mean(data_to_plot,axis=1))
plt.xticks(times_of_interest)
plt.xlabel('Time')
plt.ylabel('Optimal Omega')
plt.title('Optimal Omega averaged over channels')



df_posterior = TP_funcs.for_plot_posterior_probability(config.subjects_list,omega_list=range(1,299))
df_post = np.mean(df_posterior['posterior'],axis=0)
Example #11
def regress_surprise_in_cluster(subject_list,
                                cluster_info,
                                omega_list=range(1, 300),
                                clean=True):
    """
    This function regresses the data within a cluster as a function of the surprise for all the omegas specified in omega_list.
    4 different types of regressions are considered,
    1 - 'original_data' : for each channel and time-point
    2 - 'average_time' : averaging the data across time
    3 - 'average_channels' : averaging the data across channels
    4 - 'average_channels_and_times' : averaging the data across channels and time

    :param subject_list:
    :param cluster_info: a dictionnary containing  the keys
    'sig_times': the significant times
    'channels_cluster' : the channels that are significant
    'ch_type': the type of channel
    :param omega_list: The list of omegas for which we compute the regressions.
    :return: dataFrame containing the results of all the regressions
    """

    sig_times = cluster_info['sig_times']
    sig_channels = cluster_info['channels_cluster']
    ch_type = cluster_info['ch_type']

    results = {subject: {} for subject in subject_list}

    for subject in subject_list:
        results[subject] = {
            'average_channels_and_times': {},
            'average_channels': {},
            'average_times': {},
            'original_data': {}
        }

        epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
        metadata = epoching_funcs.update_metadata(subject,
                                                  clean=clean,
                                                  new_field_name=None,
                                                  new_field_values=None)
        epochs.metadata = metadata

        if ch_type in ['grad', 'mag']:
            epochs.pick_types(meg=ch_type, eeg=False)
        elif ch_type == 'eeg':
            epochs.pick_types(meg=False, eeg=True)
        else:
            raise ValueError('Invalid ch_type: %s' % ch_type)

        epochs = epochs[np.where(
            1 - np.isnan(epochs.metadata["surprise_1"].values))[0]]
        # ========= select the significant times and channels ======
        epochs.crop(tmin=np.min(sig_times), tmax=np.max(sig_times))
        epochs.pick(sig_channels)  # not sure this is working actually
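        # the averaged copies below keep a 3D (epochs x channels x times) array
        # with singleton dimensions, so they can be passed to the same
        # regression helper as the original data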
        epochs_avg_time = epochs.copy()
        epochs_avg_time._data = np.transpose(
            np.mean(epochs_avg_time._data, axis=2)[:, np.newaxis], (0, 2, 1))
        epochs_avg_channels = epochs.copy()
        epochs_avg_channels._data = np.mean(epochs_avg_channels._data,
                                            axis=1)[:, np.newaxis]
        epochs_avg_channels_times = epochs.copy()
        epochs_avg_channels_times._data = np.mean(epochs_avg_time._data,
                                                  axis=1)[:, np.newaxis]

        # ============== And now the regressions =============================================================
        for key in results[subject].keys():
            results[subject][key] = {omega: {} for omega in omega_list}

        for omega in omega_list:
            print("==== running the regression for omega %i =======" % omega)
            surprise_name = "surprise_%i" % omega
            results[subject]['original_data'][
                omega] = linear_regression_from_sklearn(epochs, surprise_name)
            results[subject]['average_times'][
                omega] = linear_regression_from_sklearn(
                    epochs_avg_time, surprise_name)
            results[subject]['average_channels'][
                omega] = linear_regression_from_sklearn(
                    epochs_avg_channels, surprise_name)
            results[subject]['average_channels_and_times'][
                omega] = linear_regression_from_sklearn(
                    epochs_avg_channels_times, surprise_name)

    return results
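Example call (sketch; the cluster_info values are hypothetical):

# cluster_info = {'sig_times': np.array([0.10, 0.15, 0.20]),
#                 'channels_cluster': ['MEG0113', 'MEG0122'],
#                 'ch_type': 'grad'}
# results = regress_surprise_in_cluster(config.subjects_list, cluster_info)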
Example #12
def regress_out_optimal_omega_per_channel(subject, clean=True):

    # =========== load the optimal parameters =========
    load_optimal = op.join(config.result_path, 'TP_effects', 'surprise_omegas',
                           'omega_optimal_per_channels.npy')
    optimal_omega = np.load(load_optimal, allow_pickle=True).item()
    optimal_omega = optimal_omega['omega_arg_max']
    optimal_omega = np.mean(optimal_omega, axis=0)

    # =========== load the data on which to perform the regression =========
    epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
    metadata = epoching_funcs.update_metadata(subject,
                                              clean=clean,
                                              new_field_name=None,
                                              new_field_values=None)
    epochs.metadata = metadata
    epochs.pick_types(meg=True, eeg=True)
    epochs = epochs[np.where(
        1 - np.isnan(epochs.metadata["surprise_1"].values))[0]]
    y = epochs.get_data()

    # =========== we initialize the output =========
    n_trials, n_channels, n_times = y.shape
    residual_model_no_constant = np.zeros((n_trials, n_channels, n_times))
    residual_model_constant = np.zeros((n_trials, n_channels, n_times))
    residual_constant = np.zeros((n_trials, n_channels, n_times))
    residual_surprise = np.zeros((n_trials, n_channels, n_times))
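    # four kinds of residuals are stored (trials x channels x times):
    # - residual_model_constant:    data minus the full fit (intercept + surprise)
    # - residual_constant:          data minus the intercept term
    # - residual_surprise:          data minus the surprise contribution
    # - residual_model_no_constant: data minus the fit obtained without an intercept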

    # ======== we run the regression for each time point and each channel ===============
    for time in range(n_times):
        for channel in range(n_channels):
            print(
                "----- running the regression for time %i and channel %i -----"
                % (time, channel))
            surprise_name = "surprise_%i" % int(
                np.round(optimal_omega[time, channel], 0))
            x = np.asarray(epochs.metadata[surprise_name])[:, np.newaxis]
            # ========== regression with constant ==============
            reg_with_constant = LinearRegression().fit(x, y[:, channel, time])
            # ========== regression without constant ==============
            reg_without_constant = LinearRegression(fit_intercept=False).fit(
                x, y[:, channel, time])

            # residuals after removing the full model (intercept + surprise)
            residual_model_constant[:, channel, time] = (
                y[:, channel, time] - reg_with_constant.predict(x))
            # residuals after removing the constant (intercept) term only
            residual_constant[:, channel, time] = (
                y[:, channel, time] - reg_with_constant.intercept_)
            # residuals after removing the surprise contribution only
            residual_surprise[:, channel, time] = (
                y[:, channel, time] - np.squeeze(reg_with_constant.coef_ * x))
            # residuals after removing the model fitted without an intercept
            residual_model_no_constant[:, channel, time] = (
                y[:, channel, time] - reg_without_constant.predict(x))

    # ============================================================================================================

    epo_residual_model_constant = epochs.copy()
    epo_residual_constant = epochs.copy()
    epo_residual_surprise = epochs.copy()
    epo_residual_model_no_constant = epochs.copy()

    epo_residual_model_constant._data = residual_model_constant
    epo_residual_constant._data = residual_constant
    epo_residual_surprise._data = residual_surprise
    epo_residual_model_no_constant._data = residual_model_no_constant

    save_name = op.join(config.meg_dir, subject,
                        subject + 'residual_model_constant-epo.fif')
    epo_residual_model_constant.save(save_name, overwrite=True)

    save_name = op.join(config.meg_dir, subject,
                        subject + 'residual_constant-epo.fif')
    epo_residual_constant.save(save_name, overwrite=True)

    save_name = op.join(config.meg_dir, subject,
                        subject + 'residual_surprise-epo.fif')
    epo_residual_surprise.save(save_name, overwrite=True)

    save_name = op.join(config.meg_dir, subject,
                        subject + 'residual_model_no_constant-epo.fif')
    epo_residual_model_no_constant.save(save_name, overwrite=True)
Example #13
def regress_out_optimal_omega(subject, clean=True):
    """
    This function computes the regression for each time step and each optimal omega, i.e. that explains the most the variance
    :param epochs_for_reg: the epochs data on which we run the regression
    :param names: the regression variables
    :return:
    """

    save_name = op.join(config.result_path, 'TP_effects', 'surprise_omegas',
                        'argmax_omega.npy')
    omega_argmax = np.load(save_name)

    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import r2_score

    epochs = epoching_funcs.load_epochs_items(subject, cleaned=clean)
    metadata = epoching_funcs.update_metadata(subject,
                                              clean=clean,
                                              new_field_name=None,
                                              new_field_values=None)
    epochs.metadata = metadata
    epochs.pick_types(meg=True, eeg=True)
    epochs = epochs[np.where(
        1 - np.isnan(epochs.metadata["surprise_1"].values))[0]]
    y = epochs.get_data()

    results = {}
    results['regcoeff_intercept'] = []
    results['regcoef_'] = []
    results['score'] = []
    results['omega'] = []
    results['residual'] = []
    results['times'] = epochs.times
    results['predictions'] = []
    results['score_per_channel'] = []

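    # one regression per time sample; all channels are fitted jointly as
    # multi-output targets of the single surprise regressor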
    for time in range(y.shape[2]):
        print("======= time step %i ==========" % time)
        surprise_name = "surprise_%i" % omega_argmax[time]
        x = np.asarray(epochs.metadata[surprise_name])
        reg = LinearRegression().fit(x[:, np.newaxis], y[:, :, time])
        results['regcoeff_intercept'].append(reg.intercept_)
        results['regcoef_'].append(reg.coef_)
        results['omega'].append(omega_argmax[time])
        y_pred = reg.predict(x[:, np.newaxis])
        # r2_score expects (y_true, y_pred) in that order
        r2 = [r2_score(y[:, k, time], y_pred[:, k]) for k in range(y.shape[1])]
        results['score_per_channel'].append(r2)
        results['score'].append(reg.score(x[:, np.newaxis], y[:, :, time]))
        results['predictions'].append(np.matmul(reg.coef_, x[:, np.newaxis].T))
        y_residual_time = y[:, :, time] - np.matmul(reg.coef_,
                                                    x[:, np.newaxis].T).T
        results['residual'].append(y_residual_time)

    for key in results.keys():
        results[key] = np.asarray(results[key])

    epochs_residual = epochs.copy()
    epochs_reg_coeff_surprise = epochs.copy()

    # results['residual'] has shape (n_times, n_trials, n_channels); reorder to (n_trials, n_channels, n_times)
    epochs_residual._data = np.transpose(results['residual'], (1, 2, 0))
    save_name = op.join(config.meg_dir, subject,
                        subject + '_residuals_surprise-epo.fif')
    epochs_residual.save(save_name, overwrite=True)

    # ============= the topomap of these predictions reveals the topographic contribution of the surprise =============
    # results['predictions'] has shape (n_times, n_channels, n_trials); reorder to (n_trials, n_channels, n_times)
    epochs_reg_coeff_surprise._data = np.transpose(results['predictions'],
                                                   (2, 1, 0))
    save_name = op.join(config.meg_dir, subject,
                        subject + '_regcoeff_surprise-epo.fif')
    epochs_reg_coeff_surprise.save(save_name, overwrite=True)

    res_fname = op.join(config.result_path, 'TP_effects', 'surprise_omegas',
                        subject, 'residuals_results.npy')
    np.save(res_fname, results)

    return results