Ejemplo n.º 1
0
def get_data_by_balanced_folds(ASs, fold_idxs, required_num_samples=None):
    """Build per-fold train/test arrays balanced across autonomous systems.

    For each AS in `ASs`, loads its unified dataset (train+val concatenated)
    and slices it with the precomputed per-fold indices in `fold_idxs`,
    accumulating the slices of every AS into shared fold containers.

    ASs are dropped (with a printed notice) when:
      * their data contains no nonzero entries ("no common handovers"),
      * their sample count differs from `required_num_samples` (when given),
      * the fold indices are out of range for their data.

    Args:
        ASs: iterable of autonomous-system identifiers to include.
        fold_idxs: mapping fold_idx -> {'train_idxs': ..., 'test_idxs': ...}
            as produced by `get_fold_idxs`.
        required_num_samples: optional exact sample count each AS must have.

    Returns:
        dict fold_idx -> {'X_train', 'X_test', 'y_train', 'y_test'}, each a
        numpy array stacked along axis 0.

    Side effects: temporarily overwrites the global 'autonomous_systems'
    setting; it is restored before returning.
    """
    prev_autonomous_systems = global_vars.get('autonomous_systems')
    folds = {i: {'X_train': [], 'X_test': [], 'y_train': [], 'y_test': []} for i in range(global_vars.get('n_folds'))}
    for AS in ASs:
        global_vars.set('autonomous_systems', [AS])
        dataset = get_dataset('all')
        concat_train_val_sets(dataset)
        dataset = unify_dataset(dataset)
        if np.count_nonzero(dataset.X) == 0:
            print(f'dropped AS {AS} - no common handovers')
            continue
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would silently admit ASs with a mismatched sample count.
        if required_num_samples is not None and len(dataset.X) != required_num_samples:
            print(f'dropped AS {AS}')
            continue
        try:
            for fold_idx in range(global_vars.get('n_folds')):
                folds[fold_idx]['X_train'].extend(dataset.X[fold_idxs[fold_idx]['train_idxs']])
                folds[fold_idx]['X_test'].extend(dataset.X[fold_idxs[fold_idx]['test_idxs']])
                folds[fold_idx]['y_train'].extend(dataset.y[fold_idxs[fold_idx]['train_idxs']])
                folds[fold_idx]['y_test'].extend(dataset.y[fold_idxs[fold_idx]['test_idxs']])
        except IndexError:
            # Fold indices don't fit this AS's data - drop it, keep going.
            print(f'dropped AS {AS}')
    for key in folds.keys():
        for inner_key in folds[key].keys():
            folds[key][inner_key] = np.stack(folds[key][inner_key], axis=0)
    global_vars.set('autonomous_systems', prev_autonomous_systems)
    return folds
Ejemplo n.º 2
0
def performance_frequency_report(pretrained_model, dataset, folder_name):
    """Produce a per-frequency performance report: PDF plots, CSV export, and
    sacred scalar logs.

    For every subject in 'subjects_to_check', evaluates `pretrained_model`
    on the unfiltered test set (baseline), then sweeps each frequency in
    [low_freq, high_freq], band-filtering the data around that frequency and
    evaluating (a) the original model ("no retraining") and (b) the model
    pretrained on similarly filtered data ("with retraining").

    NOTE(review): the `dataset` parameter is unused here - per-subject data
    is fetched via `get_dataset(subject)` instead; verify against callers.
    """
    report_file_name = f'{folder_name}/{global_vars.get("report")}_{global_vars.get("band_filter").__name__}.pdf'
    # Idempotence: skip all work if the report was already generated.
    if os.path.isfile(report_file_name):
        return
    baselines = []
    # One model per frequency, pretrained on data filtered to that band.
    freq_models = pretrain_model_on_filtered_data(pretrained_model, global_vars.get('low_freq'),
                                                  global_vars.get('high_freq'))
    all_performances = []
    all_performances_freq = []
    for subject in global_vars.get('subjects_to_check'):
        single_subj_performances = []
        single_subj_performances_freq = []
        single_subj_dataset = get_dataset(subject)
        # Baseline: original model on unperturbed test data.
        baselines.append(evaluate_single_model(pretrained_model, single_subj_dataset['test'].X,
                                                     single_subj_dataset['test'].y,
                                                     eval_func=get_eval_function()))
        for freq in range(global_vars.get('low_freq'), global_vars.get('high_freq') + 1):
            # Deep-copy so each frequency filters pristine subject data.
            single_subj_dataset_freq = deepcopy(single_subj_dataset)
            for section in ['train', 'valid', 'test']:
                # Band-filter around `freq` (roughly freq-1 .. freq+1 Hz,
                # clamped at 1 on the low side).
                single_subj_dataset_freq[section].X = global_vars.get('band_filter')\
                    (single_subj_dataset_freq[section].X, max(1, freq - 1), freq + 1, global_vars.get('frequency')).astype(np.float32)
            # NOTE(review): deepcopy of the per-frequency model each subject
            # iteration - presumably to guard against in-place state; confirm
            # whether evaluate_single_model actually mutates the model.
            pretrained_model_copy_freq = deepcopy(freq_models[freq])
            # "no retraining": original model evaluated on filtered data.
            single_subj_performances.append(evaluate_single_model(pretrained_model, single_subj_dataset_freq['test'].X,
                                                                  single_subj_dataset['test'].y, get_eval_function()))
            # "with retraining": frequency-specific model on the same data.
            single_subj_performances_freq.append(evaluate_single_model(pretrained_model_copy_freq, single_subj_dataset_freq['test'].X,
                                                                  single_subj_dataset['test'].y, get_eval_function()))
        all_performances.append(single_subj_performances)
        all_performances_freq.append(single_subj_performances_freq)
    # Append the across-subject average as an extra "subject" row.
    baselines.append(np.average(baselines, axis=0))
    all_performances.append(np.average(all_performances, axis=0))
    all_performances_freq.append(np.average(all_performances_freq, axis=0))
    export_performance_frequency_to_csv(all_performances, all_performances_freq, baselines, folder_name)
    performance_plot_imgs = plot_performance_frequency([all_performances, all_performances_freq], baselines,
                                                       legend=['no retraining', 'with retraining', 'unperturbed'])
    # Log every (subject, frequency) point to sacred; the last row is the avg.
    for subj_idx in range(len(all_performances)):
        for perf_idx in range(len(all_performances[subj_idx])):
            if subj_idx == len(all_performances) - 1:
                subj_str = 'avg'
            else:
                subj_str = subj_idx
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} no retrain', all_performances[subj_idx][perf_idx],
                                                    global_vars.get('low_freq') + perf_idx)
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} retrain', all_performances_freq[subj_idx][perf_idx],
                                                    global_vars.get('low_freq') + perf_idx)
            # Baseline is constant per subject; re-logged at each frequency
            # step so it plots as a flat reference line.
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} baseline', baselines[subj_idx],
                                                    global_vars.get('low_freq') + perf_idx)
    story = [get_image(tf) for tf in performance_plot_imgs]
    create_pdf_from_story(report_file_name, story)
    # Plot images were temp files; remove them once embedded in the PDF.
    for tf in performance_plot_imgs:
        os.remove(tf)
Ejemplo n.º 3
0
def per_subject_exp(exp_name, csv_file, subjects):
    """Run an EEGNAS evolution experiment independently for each subject.

    Writes a CSV header to `csv_file`, then for every subject loads either
    that subject's dataset or the pooled 'all' dataset (when
    'pure_cross_subject' is set) and evolves an architecture for it.

    Returns a single-element list with the best model filename when running
    pure cross-subject or with exactly one subject; otherwise returns None
    after exhausting all subjects.
    """
    stop_criterion, iterator, loss_function, monitors = get_settings()
    with open(csv_file, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=FIELDNAMES).writeheader()
    for subject_id in subjects:
        # Pooled data for pure cross-subject runs, else this subject's own.
        pure_cross = global_vars.get('pure_cross_subject')
        dataset = get_dataset('all') if pure_cross else get_dataset(subject_id)
        train_set = {subject_id: dataset['train']}
        val_set = {subject_id: dataset['valid']}
        test_set = {subject_id: dataset['test']}
        evolution_file = f'results/{exp_name}/subject_{subject_id}_archs.txt'
        eegnas = EEGNAS_evolution(iterator=iterator, exp_folder=f"results/{exp_name}", exp_name=exp_name,
                            train_set=train_set, val_set=val_set, test_set=test_set,
                            stop_criterion=stop_criterion, monitors=monitors, loss_function=loss_function,
                            subject_id=subject_id, fieldnames=FIELDNAMES, strategy='per_subject',
                            evolution_file=evolution_file, csv_file=csv_file)
        # DEAP-based evolution is an alternative backend behind a flag.
        if global_vars.get('deap'):
            best_model_filename = eegnas.evolution_deap()
        else:
            best_model_filename = eegnas.evolution()
        if global_vars.get('pure_cross_subject') or len(subjects) == 1:
            return [best_model_filename]
Ejemplo n.º 4
0
def get_fold_idxs(AS):
    """Compute cross-validation fold indices for a single autonomous system.

    Uses TimeSeriesSplit when 'k_fold_time' is set (ordered, no shuffling),
    otherwise a shuffled KFold, both with 'n_folds' splits.

    Returns:
        dict fold_num -> {'train_idxs': ndarray, 'test_idxs': ndarray}.

    Side effects: temporarily overwrites the global 'autonomous_systems'
    setting; it is restored before returning.
    """
    n_folds = global_vars.get('n_folds')
    if global_vars.get('k_fold_time'):
        splitter = TimeSeriesSplit(n_splits=n_folds)
    else:
        splitter = KFold(n_splits=n_folds, shuffle=True)
    saved_systems = global_vars.get('autonomous_systems')
    global_vars.set('autonomous_systems', [AS])
    dataset = get_dataset('all')
    concat_train_val_sets(dataset)
    dataset = unify_dataset(dataset)
    sample_ids = list(range(len(dataset.X)))
    fold_idxs = {fold_num: {'train_idxs': train_index, 'test_idxs': test_index}
                 for fold_num, (train_index, test_index) in enumerate(splitter.split(sample_ids))}
    global_vars.set('autonomous_systems', saved_systems)
    return fold_idxs
Ejemplo n.º 5
0
def cross_subject_exp(exp_name, csv_file, subjects):
    """Run one EEGNAS evolution over the pooled data of all subjects.

    Writes a CSV header, gathers every subject's train/valid/test splits into
    dicts keyed by subject id, and evolves a single architecture across them.

    Returns a single-element list with the best model filename.
    """
    stop_criterion, iterator, loss_function, monitors = get_settings()
    # Default the sampling rate to "use every subject" when unset.
    if not global_vars.get('cross_subject_sampling_rate'):
        global_vars.set('cross_subject_sampling_rate', len(subjects))
    with open(csv_file, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=FIELDNAMES).writeheader()
    train_set_all = {}
    val_set_all = {}
    test_set_all = {}
    for subject_id in subjects:
        subject_data = get_dataset(subject_id)
        train_set_all[subject_id] = subject_data['train']
        val_set_all[subject_id] = subject_data['valid']
        test_set_all[subject_id] = subject_data['test']
    evolution_file = f'results/{exp_name}/archs.txt'
    # NOTE(review): subject_id below is whatever the loop left behind (the
    # last subject) - and raises NameError if `subjects` is empty; confirm
    # this is the intended id for a cross-subject run.
    eegnas = EEGNAS_evolution(iterator=iterator, exp_folder=f"results/{exp_name}", exp_name=exp_name,
                              train_set=train_set_all, val_set=val_set_all, test_set=test_set_all,
                              stop_criterion=stop_criterion, monitors=monitors, loss_function=loss_function,
                              subject_id=subject_id, fieldnames=FIELDNAMES, strategy='cross_subject',
                              evolution_file=evolution_file, csv_file=csv_file)
    return [eegnas.evolution()]
Ejemplo n.º 6
0
        if global_vars.get('model_alias') == 'nsga' and global_vars.get(
                'explainer') == 'integrated_gradients':
            continue
        if global_vars.get('model_alias'):
            alias = global_vars.get('model_alias')
            global_vars.set('models_dir', MODEL_ALIASES[alias][0])
            global_vars.set('model_name', MODEL_ALIASES[alias][1])
        global_vars.set('band_filter', {
            'pass': butter_bandpass_filter,
            'stop': butter_bandstop_filter
        }[global_vars.get('band_filter')])

        set_params_by_dataset('../configurations/dataset_params.ini')
        subject_id = global_vars.get('subject_id')
        if global_vars.get('model_alias') != 'Ensemble':
            dataset = get_dataset(subject_id)
        exp_name = f"{exp_id}_{index+1}_{global_vars.get('report')}_{global_vars.get('dataset')}"
        exp_name = add_params_to_name(exp_name, multiple_values)
        stop_criterion, iterator, loss_function, monitors = get_normal_settings(
        )
        trainer = NN_Trainer(iterator, loss_function, stop_criterion, monitors)

        if global_vars.get('model_name') == 'rnn':
            model = MultivariateLSTM(dataset['train'].X.shape[1],
                                     100,
                                     global_vars.get('batch_size'),
                                     global_vars.get('input_height'),
                                     global_vars.get('n_classes'),
                                     eegnas=True)
        elif global_vars.get('model_name') == 'MHANet':
            model = MHANetModel(dataset['train'].X.shape[1],