Пример #1
0
def load_sentence_data(subject, word, sen_type, experiment, proc, num_instances, reps_to_use, sorted_inds=None):
    """Load averaged data for one subject, ordered by sentence id.

    Raw trials come from ``load_data.load_raw`` and are averaged with
    ``load_data.avg_data``. The averaged data and labels are then reordered
    by ``np.argsort`` of the averaged sentence ids.

    Args:
        subject, word, sen_type: Forwarded positionally to
            ``load_data.load_raw``.
        experiment, proc: Forwarded to the loaders as keyword arguments.
        num_instances, reps_to_use: Forwarded to ``load_data.avg_data``.
        sorted_inds: Optional index applied to the second axis of the
            averaged data. ``None`` leaves that axis untouched.

    Returns:
        Tuple ``(data, labels, time)`` where ``labels`` is a numpy array and
        ``time`` is returned exactly as ``load_data.load_raw`` produced it.
    """
    raw_trials, raw_labels, time, raw_ids = load_data.load_raw(
        subject, word, sen_type, experiment=experiment, proc=proc)
    averaged, avg_labels, avg_ids = load_data.avg_data(
        raw_trials,
        raw_labels,
        sentence_ids_raw=raw_ids,
        experiment=experiment,
        num_instances=num_instances,
        reps_to_use=reps_to_use)

    # Reorder labels and data together so both follow ascending sentence id.
    label_array = np.array(avg_labels)
    order = np.argsort(avg_ids)
    ordered_labels = label_array[order]
    ordered_data = averaged[order, :, :]

    if sorted_inds is None:
        return ordered_data, ordered_labels, time
    return ordered_data[:, sorted_inds, :], ordered_labels, time
Пример #2
0
def run_tgm_exp(experiment,
                subject,
                sen_type,
                word,
                win_len,
                overlap,
                mode='pred',
                isPDTW=False,
                isPerm=False,
                num_folds=2,
                alg='GNB',
                doFeatSelect=False,
                doZscore=False,
                doAvg=False,
                num_instances=2,
                reps_to_use=10,
                proc=load_data.DEFAULT_PROC,
                random_state_perm=1,
                random_state_cv=CV_RAND_STATE,
                random_state_sub=SUB_CV_RAND_STATE,
                force=False):
    """Run one windowed classification job and save the result as ``.npz``.

    The output path is built from ``SAVE_FILE`` and every option below. If
    ``<path>.npz`` already exists and ``force`` is false, the job is skipped.

    Args:
        experiment: Experiment name; formats ``TOP_DIR`` and is passed to
            the loaders.
        subject, word: Passed to the loaders and encoded in the filename.
        sen_type: Must be in ``VALID_SEN_TYPE``. With ``isPDTW``, 'active'
            selects the active arrays and anything else the passive ones.
        win_len: Window length handed to the models. A negative value is
            replaced by ``total_win - overlap``.
        overlap: Used as the step between successive window starts.
        mode: Must be in ``VALID_MODE``. 'pred' and 'uni' are handled
            explicitly; any other valid mode takes the coefficient branch.
        isPDTW: Load with ``load_data.load_pdtw`` instead of
            ``load_data.load_raw``.
        isPerm: Shuffle the labels in place (seeded with
            ``random_state_perm``) before fitting.
        num_folds: Number of CV splits; more than 8 switches from
            ``StratifiedKFold`` to ``KFold``.
        alg: Must be in ``VALID_ALGS``; 'LASSO' and 'GNB' are implemented.
        doFeatSelect: Forwarded to ``models.nb_tgm`` as ``feature_select``;
            also adds a '-FS' suffix to the algorithm tag for GNB.
        doZscore, doAvg: Forwarded to the model functions.
        num_instances, reps_to_use: Forwarded to ``load_data.avg_data``.
        proc: Forwarded to the loaders and stored in the output file.
        random_state_perm: Seed for the label permutation.
        random_state_cv: Random state of the cross-validation splitter.
        random_state_sub: Forwarded to ``models.nb_tgm`` as ``sub_rs``.
        force: Recompute even when the output file already exists.

    Returns:
        None. Results are written with ``np.savez_compressed``.

    Raises:
        ValueError: On an invalid ``alg``, ``sen_type`` or ``mode``, on an
            algorithm that is not implemented for the chosen mode, or when
            the computed window span exceeds the number of time points.
    """
    # NOTE: this installs a process-wide filter that turns every warning
    # into an exception for the rest of the process.
    warnings.filterwarnings('error')
    # Save Directory
    top_dir = TOP_DIR.format(exp=experiment)
    if not os.path.exists(top_dir):
        os.mkdir(top_dir)
    save_dir = SAVE_DIR.format(top_dir=top_dir, sub=subject)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if alg not in VALID_ALGS:
        raise ValueError('invalid alg {}: must be {}'.format(alg, VALID_ALGS))
    if sen_type not in VALID_SEN_TYPE:
        raise ValueError('invalid sen_type {}: must be {}'.format(
            sen_type, VALID_SEN_TYPE))
    if mode not in VALID_MODE:
        raise ValueError('invalid mode {}: must be {}'.format(
            mode, VALID_MODE))

    if alg == 'GNB' and doFeatSelect:
        alg_str = alg + '-FS'
    else:
        alg_str = alg

    fname = SAVE_FILE.format(dir=save_dir,
                             sub=subject,
                             sen_type=sen_type,
                             word=word,
                             win_len=win_len,
                             overlap=overlap,
                             pdtw=bool_to_str(isPDTW),
                             perm=bool_to_str(isPerm),
                             num_folds=num_folds,
                             alg=alg_str,
                             zscore=bool_to_str(doZscore),
                             doAvg=bool_to_str(doAvg),
                             inst=num_instances,
                             rep=reps_to_use,
                             rsP=random_state_perm,
                             rsC=random_state_cv,
                             rsS=random_state_sub,
                             mode=mode)

    # np.savez_compressed appends '.npz', so that is the file to look for.
    if os.path.isfile(fname + '.npz') and not force:
        print('Job already completed. Skipping Job.')
        print(fname)
        return

    if isPDTW:
        (time_a, time_p, labels, active_data_raw,
         passive_data_raw) = load_data.load_pdtw(subject=subject,
                                                 word=word,
                                                 experiment=experiment,
                                                 proc=proc)
        if sen_type == 'active':
            data_raw = active_data_raw
            time = time_a
        else:
            data_raw = passive_data_raw
            time = time_p
    else:
        data_raw, labels, time, _ = load_data.load_raw(subject=subject,
                                                       word=word,
                                                       sen_type=sen_type,
                                                       experiment=experiment,
                                                       proc=proc)

    print(data_raw.shape)

    data, labels, _ = load_data.avg_data(data_raw=data_raw,
                                         labels_raw=labels,
                                         experiment=experiment,
                                         num_instances=num_instances,
                                         reps_to_use=reps_to_use)
    print(data.shape)

    if isPerm:
        random.seed(random_state_perm)
        random.shuffle(labels)

    tmin = time.min()
    tmax = time.max()

    if num_folds > 8:
        kf = KFold(n_splits=num_folds,
                   shuffle=True,
                   random_state=random_state_cv)
    else:
        kf = StratifiedKFold(n_splits=num_folds,
                             shuffle=True,
                             random_state=random_state_cv)

    # NOTE(review): 500 is presumably the sampling rate in Hz -- confirm.
    total_win = int((tmax - tmin) * 500)

    if win_len < 0:
        win_len = total_win - overlap

    win_starts = range(0, total_win - win_len, overlap)

    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if total_win > len(time):
        raise ValueError(
            'total_win ({}) exceeds the number of time points ({})'.format(
                total_win, len(time)))

    # Run TGM
    if mode == 'pred':
        if alg == 'LASSO':
            (preds, l_ints, cv_membership,
             masks) = models.lr_tgm(data=data,
                                    labels=labels,
                                    kf=kf,
                                    win_starts=win_starts,
                                    win_len=win_len,
                                    doZscore=doZscore,
                                    doAvg=doAvg)
            np.savez_compressed(fname,
                                preds=preds,
                                l_ints=l_ints,
                                cv_membership=cv_membership,
                                masks=masks,
                                time=time,
                                win_starts=win_starts,
                                proc=proc)
        elif alg == 'GNB':
            (preds, l_ints, cv_membership, masks,
             num_feats) = models.nb_tgm(data=data,
                                        labels=labels,
                                        kf=kf,
                                        sub_rs=random_state_sub,
                                        win_starts=win_starts,
                                        win_len=win_len,
                                        feature_select=doFeatSelect,
                                        doZscore=doZscore,
                                        doAvg=doAvg)
            print(num_feats)
            np.savez_compressed(fname,
                                preds=preds,
                                l_ints=l_ints,
                                cv_membership=cv_membership,
                                masks=masks,
                                time=time,
                                win_starts=win_starts,
                                proc=proc,
                                num_feats=num_feats)
        else:
            raise ValueError('ENET not implemented yet.')
    elif mode == 'uni':
        if alg == 'GNB':
            (preds, l_ints,
             cv_membership) = models.nb_tgm_uni(data=data,
                                                labels=labels,
                                                kf=kf,
                                                win_starts=win_starts,
                                                win_len=win_len,
                                                doZscore=doZscore,
                                                doAvg=doAvg)
            np.savez_compressed(fname,
                                preds=preds,
                                l_ints=l_ints,
                                cv_membership=cv_membership,
                                time=time,
                                win_starts=win_starts,
                                proc=proc)
        else:
            raise ValueError('Only GNB available.')
    else:
        # Any remaining valid mode: fit on all data and save the parameters.
        if alg == 'LASSO':
            coef = models.lr_tgm_coef(data=data,
                                      labels=labels,
                                      win_starts=win_starts,
                                      win_len=win_len,
                                      doZscore=doZscore,
                                      doAvg=doAvg)
            np.savez_compressed(fname,
                                coef=coef,
                                time=time,
                                win_starts=win_starts,
                                proc=proc)
        elif alg == 'GNB':
            mu_win, std_win, mu_diff_win = models.nb_tgm_coef(
                data=data,
                labels=labels,
                win_starts=win_starts,
                win_len=win_len,
                doZscore=doZscore,
                doAvg=doAvg)
            np.savez_compressed(fname,
                                mu_win=mu_win,
                                std_win=std_win,
                                mu_diff_win=mu_diff_win,
                                time=time,
                                win_starts=win_starts,
                                proc=proc)
        else:
            raise ValueError('ENET not implemented yet.')
Пример #3
0
def run_diag_exp(experiment,
                 subject,
                 sen_type,
                 word,
                 overlap,
                 mode='pred',
                 isPDTW=False,
                 isPerm=False,
                 num_folds=2,
                 alg='GNB',
                 doZscore=False,
                 doAvg=False,
                 num_instances=2,
                 reps_to_use=10,
                 proc=load_data.DEFAULT_PROC,
                 random_state_perm=1,
                 random_state_cv=CV_RAND_STATE,
                 random_state_sub=SUB_CV_RAND_STATE,
                 force=False):
    """Run one windowed classification job and save the result as ``.npz``.

    The output path is built from ``SAVE_FILE`` and the options below. If
    the saved file already exists and ``force`` is false, the job is skipped.

    Args:
        experiment: Experiment name; formats ``TOP_DIR`` and is passed to
            the loaders.
        subject, word: Passed to the loaders and encoded in the filename.
        sen_type: Must be in ``VALID_SEN_TYPE``. With ``isPDTW``, 'active'
            selects the active arrays and anything else the passive ones.
        overlap: Used as the step between successive window starts.
        mode: Must be in ``VALID_MODE``. 'pred' saves predictions and
            'coef' saves fitted coefficients.
        isPDTW: Load with ``load_data.load_pdtw`` instead of
            ``load_data.load_raw``.
        isPerm: Shuffle the labels in place (seeded with
            ``random_state_perm``) before fitting.
        num_folds: Number of splits of the outer ``StratifiedKFold``.
        alg: Must be in ``VALID_ALGS``. 'LASSO' and 'GNB' have dedicated
            branches; any other valid value uses the ``enet_*`` models.
        doZscore, doAvg: Forwarded to the model functions.
        num_instances, reps_to_use: Forwarded to ``load_data.avg_data``.
        proc: Forwarded to the loaders and stored in the output file.
        random_state_perm: Seed for the label permutation.
        random_state_cv: Random state of the outer CV splitter.
        random_state_sub: Random state of the inner 2-fold CV splitter.
        force: Recompute even when the output file already exists.

    Returns:
        None. Results are written with ``np.savez_compressed``.

    Raises:
        ValueError: On an invalid ``alg``, ``sen_type`` or ``mode``, or when
            the computed window span exceeds the number of time points.
    """
    # Save Directory
    top_dir = TOP_DIR.format(exp=experiment)
    if not os.path.exists(top_dir):
        os.mkdir(top_dir)
    save_dir = SAVE_DIR.format(top_dir=top_dir, sub=subject)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if alg not in VALID_ALGS:
        raise ValueError('invalid alg: must be {}'.format(VALID_ALGS))
    if sen_type not in VALID_SEN_TYPE:
        raise ValueError('invalid sen_type: must be {}'.format(VALID_SEN_TYPE))
    # Validate ``mode`` itself (previously ``sen_type`` was tested here).
    if mode not in VALID_MODE:
        raise ValueError('invalid mode: must be {}'.format(VALID_MODE))

    fname = SAVE_FILE.format(dir=save_dir,
                             sub=subject,
                             sen_type=sen_type,
                             word=word,
                             overlap=overlap,
                             pdtw=bool_to_str(isPDTW),
                             perm=bool_to_str(isPerm),
                             num_folds=num_folds,
                             alg=alg,
                             zscore=bool_to_str(doZscore),
                             doAvg=bool_to_str(doAvg),
                             inst=num_instances,
                             rep=reps_to_use,
                             rsP=random_state_perm,
                             rsC=random_state_cv,
                             rsS=random_state_sub,
                             mode=mode)

    # np.savez_compressed appends '.npz' when the name lacks it, so look for
    # the file that is actually written.
    saved_path = fname if fname.endswith('.npz') else fname + '.npz'
    if os.path.isfile(saved_path) and not force:
        print('Job already completed. Skipping Job.')
        print(fname)
        return

    if isPDTW:
        (time_a, time_p, labels, active_data_raw,
         passive_data_raw) = load_data.load_pdtw(subject=subject,
                                                 word=word,
                                                 experiment=experiment,
                                                 proc=proc)
        if sen_type == 'active':
            data_raw = active_data_raw
            time = time_a
        else:
            data_raw = passive_data_raw
            time = time_p
    else:
        # NOTE(review): assumes a load_data version whose load_raw returns
        # three values and whose avg_data returns two -- confirm.
        data_raw, labels, time = load_data.load_raw(subject=subject,
                                                    word=word,
                                                    sen_type=sen_type,
                                                    experiment=experiment,
                                                    proc=proc)

    data, labels = load_data.avg_data(data_raw=data_raw,
                                      labels_raw=labels,
                                      experiment=experiment,
                                      num_instances=num_instances,
                                      reps_to_use=reps_to_use)
    print(data.shape)

    if isPerm:
        random.seed(random_state_perm)
        random.shuffle(labels)

    tmin = time.min()
    tmax = time.max()

    # Use the caller-supplied random states: they are recorded in the
    # filename, so the splitters must honour them. The defaults are the
    # module constants used before.
    kf = StratifiedKFold(n_splits=num_folds,
                         shuffle=True,
                         random_state=random_state_cv)
    sub_kf = StratifiedKFold(n_splits=2,
                             shuffle=True,
                             random_state=random_state_sub)

    # NOTE(review): 500 is presumably the sampling rate in Hz -- confirm.
    total_win = int((tmax - tmin) * 500)
    win_starts = range(0, total_win, overlap)

    if total_win > len(time):
        raise ValueError('Windows are messed up.')

    # Run TGM
    if mode == 'pred':
        if alg == 'LASSO':
            (preds, l_ints, cv_membership,
             masks) = models.lasso_tgm(data=data,
                                       labels=labels,
                                       kf=kf,
                                       sub_kf=sub_kf,
                                       win_starts=win_starts,
                                       doZscore=doZscore,
                                       doAvg=doAvg)
        elif alg == 'GNB':
            # NOTE(review): ``num_feats`` is not defined in this function;
            # unless it exists at module level this branch raises NameError.
            (preds, l_ints, cv_membership, masks) = models.nb_tgm(
                data=data,
                labels=labels,
                kf=kf,
                sub_kf=sub_kf,
                win_starts=win_starts,
                feature_select='distance_of_means',
                feature_select_params={'number_of_features': num_feats},
                doZscore=doZscore,
                doAvg=doAvg)
        else:
            (preds, l_ints, cv_membership,
             masks) = models.enet_tgm(data=data,
                                      labels=labels,
                                      kf=kf,
                                      sub_kf=sub_kf,
                                      win_starts=win_starts,
                                      doZscore=doZscore,
                                      doAvg=doAvg)

        np.savez_compressed(fname,
                            preds=preds,
                            l_ints=l_ints,
                            cv_membership=cv_membership,
                            masks=masks,
                            time=time,
                            win_starts=win_starts,
                            proc=proc)

    elif mode == 'coef':
        win_lens = load_data.load_win_lens(experiment, subject, sen_type, word,
                                           overlap, isPDTW, num_folds, alg,
                                           doZscore, doAvg, num_instances,
                                           reps_to_use, proc, random_state_cv,
                                           random_state_sub)
        if alg == 'LASSO':
            coef = models.lasso_tgm_coef(data=data,
                                         labels=labels,
                                         win_starts=win_starts,
                                         win_lens=win_lens,
                                         doZscore=doZscore,
                                         doAvg=doAvg)
        elif alg == 'GNB':
            # NOTE(review): same undefined ``num_feats`` as in the 'pred'
            # branch. This call also ignores the caller's doZscore/doAvg.
            coef = models.nb_tgm_coef(
                data=data,
                labels=labels,
                win_starts=win_starts,
                win_len=win_lens,
                feature_select='distance_of_means',
                feature_select_params={'number_of_features': num_feats},
                doZscore=False,
                doAvg=False,
                ddof=1)
        else:
            coef = models.enet_tgm_coef(data=data,
                                        labels=labels,
                                        win_starts=win_starts,
                                        win_lens=win_lens,
                                        doZscore=doZscore,
                                        doAvg=doAvg)

        np.savez_compressed(fname,
                            coef=coef,
                            time=time,
                            win_starts=win_starts,
                            proc=proc)
Пример #4
0
def run_sv_exp(experiment,
               subject,
               sen_type,
               word,
               model='one_hot',
               inc_art1=False,
               inc_art2=False,
               only_art1=False,
               only_art2=False,
               direction='encoding',
               doPCA=False,
               isPDTW=False,
               isPerm=False,
               num_folds=2,
               alg='ridge',
               adj='mean_center',
               num_instances=1,
               reps_to_use=10,
               proc=load_data.DEFAULT_PROC,
               random_state_perm=1,
               random_state_cv=CV_RAND_STATE,
               random_state_sub=SUB_CV_RAND_STATE,
               force=False):
    """Run one semantic-vector regression job and save it as ``.npz``.

    Averaged data are paired with per-item target vectors (GloVe or one-hot
    encodings) and passed to ``models.lin_reg`` together with integer label
    ids and a cross-validation splitter. If ``<path>.npz`` already exists
    and ``force`` is false, the job is skipped.

    Args:
        experiment: Experiment name; formats ``TOP_DIR`` and is passed to
            the loaders.
        subject, word: Passed to the loaders and encoded in the filename.
        sen_type: Must be in ``VALID_SEN_TYPE``. With ``isPDTW``, 'active'
            selects the active arrays and anything else the passive ones.
        model: 'glove' loads GloVe vectors for the labels; any other value
            builds one-hot targets.
        inc_art1, inc_art2: Append the one-hot encoding of article 1 / 2 to
            the label one-hot. Ignored when an ``only_art*`` flag is set.
        only_art1, only_art2: Use only the article one-hot encodings (both
            concatenated when both flags are set).
        direction: Encoded in the filename and selects the PCA branch.
        doPCA: Apply the PCA step (see the review note in the body).
        isPDTW: Load with ``load_data.load_pdtw`` instead of
            ``load_data.load_raw``.
        isPerm: Permute the targets relative to the data, seeded with
            ``random_state_perm``.
        num_folds: Number of CV splits; more than 8 switches from
            ``StratifiedKFold`` to ``KFold``.
        alg: Must be in ``VALID_ALGS``; forwarded to ``models.lin_reg`` as
            ``reg``.
        adj: Forwarded to ``models.lin_reg``.
        num_instances, reps_to_use: Forwarded to ``load_data.avg_data``.
        proc: Forwarded to the loaders and stored in the output file.
        random_state_perm: Seed for the permutation.
        random_state_cv: Random state of the CV splitter.
        random_state_sub: Only encoded in the filename.
        force: Recompute even when the output file already exists.

    Returns:
        None. Results are written with ``np.savez_compressed``.

    Raises:
        ValueError: On an invalid ``alg`` or ``sen_type``.
    """
    # warnings.filterwarnings('error')
    # Save Directory
    top_dir = TOP_DIR.format(exp=experiment)
    if not os.path.exists(top_dir):
        os.mkdir(top_dir)
    save_dir = SAVE_DIR.format(top_dir=top_dir, sub=subject)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if alg not in VALID_ALGS:
        raise ValueError('invalid alg {}: must be {}'.format(alg, VALID_ALGS))
    if sen_type not in VALID_SEN_TYPE:
        raise ValueError('invalid sen_type {}: must be {}'.format(
            sen_type, VALID_SEN_TYPE))

    # Filename tags for the article options: 'O' marks "only this article",
    # 'F' marks the article that is excluded by the other "only" flag.
    if only_art1 and not only_art2:
        art1_str = 'O'
        art2_str = 'F'
    elif only_art2 and not only_art1:
        art2_str = 'O'
        art1_str = 'F'
    elif only_art1 and only_art2:
        art2_str = 'O'
        art1_str = 'O'
    else:
        art1_str = bool_to_str(inc_art1)
        art2_str = bool_to_str(inc_art2)

    fname = SAVE_FILE.format(dir=save_dir,
                             sub=subject,
                             sen_type=sen_type,
                             word=word,
                             model=model,
                             art1=art1_str,
                             art2=art2_str,
                             direction=direction,
                             pca=bool_to_str(doPCA),
                             pdtw=bool_to_str(isPDTW),
                             perm=bool_to_str(isPerm),
                             num_folds=num_folds,
                             alg=alg,
                             adj=adj,
                             inst=num_instances,
                             rep=reps_to_use,
                             rsP=random_state_perm,
                             rsC=random_state_cv,
                             rsS=random_state_sub)

    # np.savez_compressed appends '.npz', so that is the file to look for.
    if os.path.isfile(fname + '.npz') and not force:
        print('Job already completed. Skipping Job.')
        print(fname)
        return

    if isPDTW:
        (time_a, time_p, labels, active_data_raw,
         passive_data_raw) = load_data.load_pdtw(subject=subject,
                                                 word=word,
                                                 experiment=experiment,
                                                 proc=proc)
        if sen_type == 'active':
            data_raw = active_data_raw
            time = time_a
        else:
            data_raw = passive_data_raw
            time = time_p
        # load_pdtw returns no sentence ids; use the row index instead.
        sentence_ids = range(data_raw.shape[0])
    else:
        data_raw, labels, time, sentence_ids = load_data.load_raw(
            subject=subject,
            word=word,
            sen_type=sen_type,
            experiment=experiment,
            proc=proc)

    print(data_raw.shape)

    data, labels, sentence_ids = load_data.avg_data(
        data_raw=data_raw,
        labels_raw=labels,
        sentence_ids_raw=sentence_ids,
        experiment=experiment,
        num_instances=num_instances,
        reps_to_use=reps_to_use)
    print(data.shape)

    # Map each distinct label to its position in the sorted unique labels.
    l_set = np.unique(labels)
    l_index = {label: idx for idx, label in enumerate(l_set)}
    l_ints = np.array([l_index[l] for l in labels])

    if model == 'glove':
        semantic_vectors = load_data.load_glove_vectors(labels)
    else:
        if only_art1 and not only_art2:
            semantic_vectors = load_data.load_one_hot(
                load_data.get_arts_from_senid(sentence_ids, 1))
        elif only_art2 and not only_art1:
            semantic_vectors = load_data.load_one_hot(
                load_data.get_arts_from_senid(sentence_ids, 2))
        elif only_art1 and only_art2:
            semantic_vectors = np.concatenate([
                load_data.load_one_hot(
                    load_data.get_arts_from_senid(sentence_ids, 1)),
                load_data.load_one_hot(
                    load_data.get_arts_from_senid(sentence_ids, 2))
            ],
                                              axis=1)
        else:
            semantic_vectors = load_data.load_one_hot(labels)
            if inc_art1:
                semantic_vectors = np.concatenate([
                    semantic_vectors,
                    load_data.load_one_hot(
                        load_data.get_arts_from_senid(sentence_ids, 1))
                ],
                                                  axis=1)
            if inc_art2:
                semantic_vectors = np.concatenate([
                    semantic_vectors,
                    load_data.load_one_hot(
                        load_data.get_arts_from_senid(sentence_ids, 2))
                ],
                                                  axis=1)

    if doPCA:
        # NOTE(review): this multiplies element-wise by components_.T, which
        # is not a PCA projection (that would be pca.transform or a matrix
        # product). Left unchanged until the intent is confirmed.
        if direction == 'encoding':
            pca = PCA()
            pca.fit(semantic_vectors)
            semantic_vectors *= np.transpose(pca.components_)
        else:
            # NOTE(review): reshaped_data is never used, and the PCA is
            # still fitted on the semantic vectors -- presumably the data
            # were meant to be reduced here. Confirm before relying on it.
            reshaped_data = np.reshape(data, (data.shape[0], -1))
            pca = PCA()
            pca.fit(semantic_vectors)
            semantic_vectors *= np.transpose(pca.components_)

    if isPerm:
        # Permute the targets relative to the data. Shuffling ``labels``
        # here would have no effect: l_ints and semantic_vectors are already
        # derived and ``labels`` is not read again. One shared permutation
        # keeps each integer id paired with its semantic vector.
        random.seed(random_state_perm)
        perm_inds = list(range(len(l_ints)))
        random.shuffle(perm_inds)
        l_ints = l_ints[perm_inds]
        semantic_vectors = np.asarray(semantic_vectors)[perm_inds]

    if num_folds > 8:
        kf = KFold(n_splits=num_folds,
                   shuffle=True,
                   random_state=random_state_cv)
    else:
        kf = StratifiedKFold(n_splits=num_folds,
                             shuffle=True,
                             random_state=random_state_cv)

    preds, l_ints, cv_membership, scores = models.lin_reg(data,
                                                          semantic_vectors,
                                                          l_ints,
                                                          kf,
                                                          reg=alg,
                                                          adj=adj,
                                                          ddof=1)
    np.savez_compressed(fname,
                        preds=preds,
                        l_ints=l_ints,
                        cv_membership=cv_membership,
                        scores=scores,
                        time=time,
                        proc=proc)
Пример #5
0
    parser.add_argument('--isPDTW', default='False')
    parser.add_argument('--num_instances', type=int, default=1)
    parser.add_argument('--reps_to_use', type=int, default=10)
    parser.add_argument('--proc', default=load_data.DEFAULT_PROC)
    args = parser.parse_args()

    evokeds, labels, time, _ = load_data.load_raw(args.subject,
                                                  args.word,
                                                  args.sen_type,
                                                  experiment=args.experiment,
                                                  proc=args.proc)
    evokeds = baseline_correct(evokeds, time)

    avg_data, labels_avg, _ = load_data.avg_data(
        evokeds,
        labels,
        experiment=args.experiment,
        num_instances=args.num_instances,
        reps_to_use=args.reps_to_use)

    sorted_inds, sorted_reg = sort_sensors()
    avg_data = avg_data[:, sorted_inds, :]
    uni_reg = np.unique(sorted_reg)
    yticks_sens = [sorted_reg.index(reg) for reg in uni_reg]

    num_time = time.size

    for i in range(2):
        fig, ax = plt.subplots()
        h = ax.imshow(np.squeeze(avg_data[i, :, :]),
                      interpolation='nearest',
                      aspect='auto',
Пример #6
0
if __name__ == '__main__':
    # Export raw trials, averaged trials and the sensor ordering for one
    # subject/condition to MATLAB .mat files.
    parser = argparse.ArgumentParser()
    for flag, default in (('--experiment', 'krns2'),
                          ('--subject', 'B'),
                          ('--sen_type', 'active'),
                          ('--word', 'firstNoun'),
                          ('--isPDTW', 'False')):
        parser.add_argument(flag, default=default)
    for flag, default in (('--num_instances', 1), ('--reps_to_use', 10)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument('--proc', default=load_data.DEFAULT_PROC)
    args = parser.parse_args()

    out_dir = SAVE_DIR.format(args.experiment)

    # Raw (un-averaged) trials.
    evokeds, labels, time, raw_ids = load_data.load_raw(
        args.subject,
        args.word,
        args.sen_type,
        experiment=args.experiment,
        proc=args.proc)
    raw_path = DATA_FNAME.format(out_dir, args.subject, args.sen_type,
                                 args.word, 'raw')
    raw_contents = {
        'evokeds': evokeds,
        'labels': labels,
        'time': time,
        'sentence_ids': raw_ids,
    }
    sio.savemat(raw_path, mdict=raw_contents)

    # Averaged trials; the file tag records how many repetitions were used.
    avg_data, labels_avg, avg_ids = load_data.avg_data(
        evokeds,
        labels,
        sentence_ids_raw=raw_ids,
        experiment=args.experiment,
        num_instances=args.num_instances,
        reps_to_use=args.reps_to_use)
    avg_tag = 'avg' + str(args.reps_to_use)
    avg_path = DATA_FNAME.format(out_dir, args.subject, args.sen_type,
                                 args.word, avg_tag)
    avg_contents = {
        'avg_data': avg_data,
        'labels_avg': labels_avg,
        'time': time,
        'sentence_ids_avg': avg_ids,
    }
    sio.savemat(avg_path, mdict=avg_contents)

    # Sensor ordering, stored once per experiment directory.
    sorted_inds, sorted_reg = sort_sensors()
    sensor_contents = {'sorted_inds': sorted_inds, 'sorted_reg': sorted_reg}
    sio.savemat(SENSOR_FNAME.format(out_dir), mdict=sensor_contents)