Example #1
0
def load_train_test(clf,
                    filedir,
                    confuse_mat_path,
                    freq_h=15,
                    use_good_sensors=False,
                    decim=10,
                    n_jobs=12,
                    scoring='accuracy'):
    """Repeat pairwise class decoding on averaged pseudo-trials and save scores.

    Epochs from the training files listed by ``para_setting`` are loaded with
    ``get_epochs`` and the first five runs are pooled. For each of 100
    repetitions the pooled trials are shuffled. Each event class is reduced to
    5 pseudo-trials, each the mean of 12 consecutive shuffled trials. Every
    pair of classes is then decoded with ``train_test`` in a 5-fold,
    leave-one-pseudo-trial-out scheme. All scores are written to
    ``confuse_mat_path`` with ``np.save``.

    Parameters
    ----------
    clf
        Classifier forwarded unchanged to ``train_test``.
    filedir : str
        Directory passed to ``para_setting`` to locate the data files.
    confuse_mat_path : str
        Target path for ``np.save``.
    freq_h : float, default 15
        Upper frequency forwarded to ``get_epochs`` (``freq_l`` is fixed at 1).
    use_good_sensors : bool, default False
        Forwarded to ``get_epochs``.
    decim : int, default 10
        Forwarded to ``get_epochs``.
    n_jobs : int, default 12
        Forwarded to ``train_test``.
    scoring : str, default 'accuracy'
        Forwarded to ``train_test``.

    Raises
    ------
    ValueError
        If the data contain more than 6 event classes (the saved array only
        has room for 6), or if some class has too few trials for every
        pseudo-trial to contain at least one trial.

    Notes
    -----
    The saved array has shape ``(6, 6, n_times, n_times, 5, 100)``. It is
    indexed ``[class_a, class_b, :, :, fold, repetition]``, where the class
    indices are positions in the sorted list of event codes. Only entries
    with ``class_a < class_b`` are filled; all others remain zero. Each slot
    receives whatever ``train_test`` returns. That is assumed to be an
    ``(n_times, n_times)`` array; TODO confirm against ``train_test``.
    """
    num_run = 5       # runs pooled into the data set: epochs_run[0..4]
    num_cross = 5     # pseudo-trials per class == number of CV folds
    num_avg = 12      # shuffled trials averaged into one pseudo-trial
    num_repeat = 100  # independent reshuffles
    num_ort = 6       # size of the two class axes of the saved array

    # initial running timer
    st = simple_timer()

    # parameters setting
    fname_list, _ortids, event_ids, tmin, t0, tmax = para_setting(
        train=True, filedir=filedir)

    # load raw data and epochs
    epochs_run = []
    for fname in fname_list:
        print(fname)
        epochs, _raw = get_epochs(fname=fname,
                                  event_id=event_ids,
                                  tmin=tmin,
                                  t0=t0,
                                  tmax=tmax,
                                  freq_l=1,
                                  freq_h=freq_h,
                                  decim=decim,
                                  use_good_sensors=use_good_sensors,
                                  get_envlop=False)
        epochs_run.append(epochs)
        st.click()

    # Stack data. np.vstack needs a real sequence: generator input was
    # deprecated in NumPy 1.16 and is rejected by current NumPy releases.
    X_all = np.vstack([epochs_run[j].get_data() for j in range(num_run)])
    # Column 2 of the events array is used as the class label (presumably the
    # event code, as in MNE events arrays -- confirm for get_epochs' output).
    y_all = np.vstack([epochs_run[j].events for j in range(num_run)])[:, 2]
    idx_list, class_counts = np.unique(y_all, return_counts=True)
    n = len(X_all)
    num_timepoint = X_all.shape[-1]

    # Fail early, with a clear message, instead of deep inside the loops.
    if len(idx_list) > num_ort:
        raise ValueError(
            'found %d event classes, but the score array only holds %d'
            % (len(idx_list), num_ort))
    # The last pseudo-trial starts at (num_cross - 1) * num_avg. A class with
    # fewer trials would produce an empty slice whose mean is NaN.
    min_trials = (num_cross - 1) * num_avg + 1
    if class_counts.size and class_counts.min() < min_trials:
        raise ValueError(
            'every event class needs at least %d trials to build %d '
            'pseudo-trials; the smallest class has %d'
            % (min_trials, num_cross, class_counts.min()))

    confuse_mat = np.zeros([
        num_ort, num_ort, num_timepoint, num_timepoint, num_cross, num_repeat
    ])

    st.click()

    for rep_ in range(num_repeat):
        # Shuffle data. Fancy indexing already returns copies, so the source
        # arrays are never modified and no extra .copy() is needed.
        s_ = np.random.permutation(n)
        X_shuff = X_all[s_]
        y_shuff = y_all[s_]

        # Build num_cross pseudo-trials per class (referred to as idx_).
        # Pseudo-trial j is the mean of shuffled trials
        # [j * num_avg, (j + 1) * num_avg). If a class has fewer than
        # num_cross * num_avg trials, the last pseudo-trial averages fewer.
        X_dict = {}
        y_dict = {}
        for i, idx_ in enumerate(idx_list):
            tmp = X_shuff[y_shuff == idx_]
            X_dict[idx_] = np.stack([
                tmp[j * num_avg:(j + 1) * num_avg].mean(axis=0)
                for j in range(num_cross)
            ])
            # Label = 1-based position of the class, as a (num_cross, 1)
            # float column.
            y_dict[idx_] = np.full((num_cross, 1), i + 1.0)

        # For each pair of classes, separate train and test data.
        for combin_ in itertools.combinations(range(len(idx_list)), 2):
            combin = [idx_list[j] for j in combin_]
            print(rep_, combin_, combin)
            for cross_ in range(num_cross):
                # Hold out pseudo-trial cross_ of each class; train on the rest.
                cross_train = [k for k in range(num_cross) if k != cross_]
                X_train = np.vstack([X_dict[j][cross_train] for j in combin])
                y_train = np.ravel(
                    np.vstack([y_dict[j][cross_train] for j in combin]))
                X_test = np.stack([X_dict[j][cross_] for j in combin])
                y_test = np.concatenate([y_dict[j][cross_] for j in combin])
                # train and test
                scores = train_test(X_train,
                                    y_train,
                                    X_test,
                                    y_test,
                                    clf=clf,
                                    scoring=scoring,
                                    n_jobs=n_jobs)
                confuse_mat[combin_[0], combin_[1], :, :, cross_,
                            rep_] = scores
                st.click()

    np.save(confuse_mat_path, confuse_mat)
from tools_loaddata import para_setting, get_epochs

sys.path.append('C:\\Users\\liste\\Documents\\Python Scripts\\clock_tools')
from simple_timer import simple_timer

# Preprocessing options for this script (presumably forwarded to get_epochs
# when the epochs are loaded below).
freq_h = 200
decim = 1
use_good_sensors = False

# Data directory for 'ZYF' (presumably a subject code).
filedir = 'D:/BeidaShuju/rawdata/%s' % 'ZYF'
# initial running timer
st = simple_timer()

# parameters setting (training files located under filedir)
fname_list, ortids, event_ids, tmin, t0, tmax = para_setting(train=True,
                                                             filedir=filedir)

# Frequencies of interest: 20 values *linearly* spaced over [10, 80]
# (np.linspace; switch to np.logspace if log spacing is really intended).
freqs = np.linspace(10, 80, num=20)
n_cycles = freqs / 2.  # cycles per frequency: half of each frequency value

# load raw data and epochs
epochs_run = []
for fname in fname_list:
    print(fname)
    epochs, raw = get_epochs(fname=fname,
                             event_id=event_ids,
                             tmin=tmin,
                             t0=t0,
                             tmax=tmax,
                             freq_l=1,
Example #3
0
                    extent=times[[0, -1, 0, -1]])
    ax.axhline(0., color='k')
    ax.axvline(0., color='k')
    ax.xaxis.set_ticks_position('bottom')
    ax.set_xlabel('Testing Time (s)')
    ax.set_ylabel('Training Time (s)')
    ax.set_title('Generalization across time and condition')
    plt.colorbar(im, ax=ax)


# Wall-clock timer; st.click() is called after each run is loaded.
st = simple_timer()

# Experiment configuration for the training session: file list, event ids
# and the epoch time window (tmin / t0 / tmax).
train = True
fname_list, ortids, event_ids, tmin, t0, tmax = para_setting(train=train)

# Read every listed file into epochs (freq_l=1, freq_h=5, decim=10 --
# filtering/decimation semantics live in get_epochs) and collect them.
epochs_run = []
for fname in fname_list:
    print(fname)
    epochs, raw = get_epochs(
        fname=fname,
        event_id=event_ids,
        tmin=tmin,
        t0=t0,
        tmax=tmax,
        freq_l=1,
        freq_h=5,
        decim=10,
        use_good_sensors=False,
        get_envlop=False,
    )
    epochs_run.append(epochs)
    st.click()

num_repeat = 100