def get_data_by_balanced_folds(ASs, fold_idxs, required_num_samples=None):
    """Build per-fold train/test arrays pooled over several autonomous systems.

    For each AS in ``ASs`` the full dataset is loaded in isolation (by
    temporarily narrowing the global ``autonomous_systems`` setting), its
    train/val sets are concatenated and unified, and its samples are sliced
    with the precomputed ``fold_idxs`` and appended to the shared folds.
    An AS is dropped (with a printed notice) when its data is all-zero,
    when it does not have ``required_num_samples`` samples, or when the
    fold indices do not fit its data.

    Parameters
    ----------
    ASs : iterable
        Autonomous-system identifiers to pool.
    fold_idxs : dict
        ``{fold_num: {'train_idxs': ..., 'test_idxs': ...}}`` as produced by
        ``get_fold_idxs``.
    required_num_samples : int, optional
        If given, an AS whose sample count differs is dropped.

    Returns
    -------
    dict
        ``{fold_num: {'X_train': ndarray, 'X_test': ndarray,
                      'y_train': ndarray, 'y_test': ndarray}}``, each array
        stacked along axis 0.
    """
    prev_autonomous_systems = global_vars.get('autonomous_systems')
    n_folds = global_vars.get('n_folds')
    folds = {i: {'X_train': [], 'X_test': [], 'y_train': [], 'y_test': []}
             for i in range(n_folds)}
    for AS in ASs:
        global_vars.set('autonomous_systems', [AS])
        dataset = get_dataset('all')
        concat_train_val_sets(dataset)
        dataset = unify_dataset(dataset)
        if np.count_nonzero(dataset.X) == 0:
            print(f'dropped AS {AS} - no common handovers')
            continue
        # Explicit check instead of `assert`: asserts disappear under
        # `python -O`, silently admitting mis-sized ASs.
        if required_num_samples is not None and len(dataset.X) != required_num_samples:
            print(f'dropped AS {AS}')
            continue
        # Stage all slices first so a failing fold cannot leave earlier
        # folds partially extended (the original extended in place and an
        # IndexError mid-loop corrupted the balance between folds).
        try:
            staged = [(dataset.X[fold_idxs[i]['train_idxs']],
                       dataset.X[fold_idxs[i]['test_idxs']],
                       dataset.y[fold_idxs[i]['train_idxs']],
                       dataset.y[fold_idxs[i]['test_idxs']])
                      for i in range(n_folds)]
        except IndexError:
            print(f'dropped AS {AS}')
            continue
        for i, (X_tr, X_te, y_tr, y_te) in enumerate(staged):
            folds[i]['X_train'].extend(X_tr)
            folds[i]['X_test'].extend(X_te)
            folds[i]['y_train'].extend(y_tr)
            folds[i]['y_test'].extend(y_te)
    for fold in folds.values():
        for inner_key in fold:
            fold[inner_key] = np.stack(fold[inner_key], axis=0)
    global_vars.set('autonomous_systems', prev_autonomous_systems)
    return folds
def performance_frequency_report(pretrained_model, dataset, folder_name):
    """Produce a per-frequency performance report (CSV + PDF + sacred logs).

    For every subject in the global ``subjects_to_check`` list, evaluates:
      * a baseline: the pretrained model on the subject's unfiltered test set;
      * "no retraining": the pretrained model on band-filtered test data,
        one evaluation per frequency in [low_freq, high_freq];
      * "with retraining": a per-frequency model (from
        ``pretrain_model_on_filtered_data``) on the same filtered test data.
    Averages across subjects are appended as a final row, results are written
    to CSV, logged to the sacred experiment, and plotted into a PDF.

    Skips all work if the report PDF already exists. Side effects: writes the
    CSV and PDF under ``folder_name``, logs scalars to ``sacred_ex``, and
    deletes the temporary plot images afterwards.

    NOTE(review): `dataset` is accepted but never used in this body — the
    per-subject data is re-fetched via ``get_dataset(subject)``; confirm the
    parameter is kept only for interface compatibility.
    """
    report_file_name = f'{folder_name}/{global_vars.get("report")}_{global_vars.get("band_filter").__name__}.pdf'
    if os.path.isfile(report_file_name):
        # Report already generated for this filter — nothing to do.
        return
    baselines = []
    # One retrained model per frequency in the scanned band.
    freq_models = pretrain_model_on_filtered_data(pretrained_model, global_vars.get('low_freq'),
                                                  global_vars.get('high_freq'))
    all_performances = []
    all_performances_freq = []
    for subject in global_vars.get('subjects_to_check'):
        single_subj_performances = []
        single_subj_performances_freq = []
        single_subj_dataset = get_dataset(subject)
        # Baseline: pretrained model on the subject's unperturbed test set.
        baselines.append(evaluate_single_model(pretrained_model, single_subj_dataset['test'].X,
                                               single_subj_dataset['test'].y, eval_func=get_eval_function()))
        for freq in range(global_vars.get('low_freq'), global_vars.get('high_freq') + 1):
            # Fresh copy per frequency so filters never stack across iterations.
            single_subj_dataset_freq = deepcopy(single_subj_dataset)
            for section in ['train', 'valid', 'test']:
                # Band-filter a narrow window around `freq`; max(1, freq - 1)
                # keeps the low cutoff positive for the first frequency.
                single_subj_dataset_freq[section].X = global_vars.get('band_filter')\
                    (single_subj_dataset_freq[section].X, max(1, freq - 1), freq + 1,
                     global_vars.get('frequency')).astype(np.float32)
            pretrained_model_copy_freq = deepcopy(freq_models[freq])
            # "No retraining": original pretrained model on filtered test data.
            single_subj_performances.append(evaluate_single_model(pretrained_model,
                                                                 single_subj_dataset_freq['test'].X,
                                                                 single_subj_dataset['test'].y,
                                                                 get_eval_function()))
            # "With retraining": the per-frequency model on the same data.
            single_subj_performances_freq.append(evaluate_single_model(pretrained_model_copy_freq,
                                                                      single_subj_dataset_freq['test'].X,
                                                                      single_subj_dataset['test'].y,
                                                                      get_eval_function()))
        all_performances.append(single_subj_performances)
        all_performances_freq.append(single_subj_performances_freq)
    # Append the across-subjects average as a final pseudo-subject row.
    baselines.append(np.average(baselines, axis=0))
    all_performances.append(np.average(all_performances, axis=0))
    all_performances_freq.append(np.average(all_performances_freq, axis=0))
    export_performance_frequency_to_csv(all_performances, all_performances_freq, baselines, folder_name)
    performance_plot_imgs = plot_performance_frequency([all_performances, all_performances_freq], baselines,
                                                       legend=['no retraining', 'with retraining', 'unperturbed'])
    for subj_idx in range(len(all_performances)):
        for perf_idx in range(len(all_performances[subj_idx])):
            # Last row is the across-subjects average appended above.
            if subj_idx == len(all_performances) - 1:
                subj_str = 'avg'
            else:
                subj_str = subj_idx
            # x-axis of the sacred scalars is the actual frequency value.
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} no retrain',
                                                    all_performances[subj_idx][perf_idx],
                                                    global_vars.get('low_freq') + perf_idx)
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} retrain',
                                                    all_performances_freq[subj_idx][perf_idx],
                                                    global_vars.get('low_freq') + perf_idx)
            global_vars.get('sacred_ex').log_scalar(f'subject {subj_str} baseline',
                                                    baselines[subj_idx],
                                                    global_vars.get('low_freq') + perf_idx)
    story = [get_image(tf) for tf in performance_plot_imgs]
    create_pdf_from_story(report_file_name, story)
    # Plot images were only needed to assemble the PDF — clean them up.
    for tf in performance_plot_imgs:
        os.remove(tf)
def per_subject_exp(exp_name, csv_file, subjects):
    """Run a per-subject NAS evolution experiment.

    Writes the CSV header, then evolves an architecture for each subject in
    turn (or once on pooled data when ``pure_cross_subject`` is set).
    Returns a single-element list with the best model filename when running
    pure-cross-subject or with exactly one subject; otherwise returns None
    after iterating all subjects.
    """
    stop_criterion, iterator, loss_function, monitors = get_settings()
    with open(csv_file, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=FIELDNAMES).writeheader()
    pure_cross = global_vars.get('pure_cross_subject')
    for subject_id in subjects:
        # Pooled data for pure-cross-subject runs, per-subject data otherwise.
        dataset = get_dataset('all') if pure_cross else get_dataset(subject_id)
        train_set = {subject_id: dataset['train']}
        val_set = {subject_id: dataset['valid']}
        test_set = {subject_id: dataset['test']}
        evolution_file = f'results/{exp_name}/subject_{subject_id}_archs.txt'
        eegnas = EEGNAS_evolution(iterator=iterator, exp_folder=f"results/{exp_name}",
                                  exp_name=exp_name, train_set=train_set, val_set=val_set,
                                  test_set=test_set, stop_criterion=stop_criterion,
                                  monitors=monitors, loss_function=loss_function,
                                  subject_id=subject_id, fieldnames=FIELDNAMES,
                                  strategy='per_subject', evolution_file=evolution_file,
                                  csv_file=csv_file)
        # DEAP-based evolution is an alternative optimizer backend.
        if global_vars.get('deap'):
            best_model_filename = eegnas.evolution_deap()
        else:
            best_model_filename = eegnas.evolution()
        if pure_cross or len(subjects) == 1:
            return [best_model_filename]
def get_fold_idxs(AS):
    """Compute k-fold train/test index splits for a single autonomous system.

    Uses ``TimeSeriesSplit`` when the global ``k_fold_time`` flag is set
    (ordering-preserving splits), otherwise a shuffled ``KFold``. The global
    ``autonomous_systems`` setting is narrowed to ``[AS]`` while loading and
    restored before returning.

    Returns a dict ``{fold_num: {'train_idxs': ndarray, 'test_idxs': ndarray}}``.
    """
    n_folds = global_vars.get('n_folds')
    splitter = (TimeSeriesSplit(n_splits=n_folds)
                if global_vars.get('k_fold_time')
                else KFold(n_splits=n_folds, shuffle=True))
    saved_systems = global_vars.get('autonomous_systems')
    global_vars.set('autonomous_systems', [AS])
    dataset = get_dataset('all')
    concat_train_val_sets(dataset)
    dataset = unify_dataset(dataset)
    sample_positions = list(range(len(dataset.X)))
    fold_idxs = {
        fold_num: {'train_idxs': train_index, 'test_idxs': test_index}
        for fold_num, (train_index, test_index) in enumerate(splitter.split(sample_positions))
    }
    global_vars.set('autonomous_systems', saved_systems)
    return fold_idxs
def cross_subject_exp(exp_name, csv_file, subjects):
    """Run a cross-subject NAS evolution experiment over all given subjects.

    Collects every subject's train/valid/test sets into shared dicts keyed by
    subject id, writes the CSV header, and runs a single 'cross_subject'
    evolution over the pooled data. Returns a single-element list with the
    evolution result.
    """
    stop_criterion, iterator, loss_function, monitors = get_settings()
    # Default the sampling rate to "use every subject" when unset.
    if not global_vars.get('cross_subject_sampling_rate'):
        global_vars.set('cross_subject_sampling_rate', len(subjects))
    with open(csv_file, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=FIELDNAMES).writeheader()
    train_set_all, val_set_all, test_set_all = {}, {}, {}
    for subject_id in subjects:
        dataset = get_dataset(subject_id)
        train_set_all[subject_id] = dataset['train']
        val_set_all[subject_id] = dataset['valid']
        test_set_all[subject_id] = dataset['test']
    evolution_file = f'results/{exp_name}/archs.txt'
    # NOTE(review): subject_id here is the last subject from the loop above,
    # matching the original behavior.
    eegnas = EEGNAS_evolution(iterator=iterator, exp_folder=f"results/{exp_name}",
                              exp_name=exp_name, train_set=train_set_all,
                              val_set=val_set_all, test_set=test_set_all,
                              stop_criterion=stop_criterion, monitors=monitors,
                              loss_function=loss_function, subject_id=subject_id,
                              fieldnames=FIELDNAMES, strategy='cross_subject',
                              evolution_file=evolution_file, csv_file=csv_file)
    return [eegnas.evolution()]
if global_vars.get('model_alias') == 'nsga' and global_vars.get( 'explainer') == 'integrated_gradients': continue if global_vars.get('model_alias'): alias = global_vars.get('model_alias') global_vars.set('models_dir', MODEL_ALIASES[alias][0]) global_vars.set('model_name', MODEL_ALIASES[alias][1]) global_vars.set('band_filter', { 'pass': butter_bandpass_filter, 'stop': butter_bandstop_filter }[global_vars.get('band_filter')]) set_params_by_dataset('../configurations/dataset_params.ini') subject_id = global_vars.get('subject_id') if global_vars.get('model_alias') != 'Ensemble': dataset = get_dataset(subject_id) exp_name = f"{exp_id}_{index+1}_{global_vars.get('report')}_{global_vars.get('dataset')}" exp_name = add_params_to_name(exp_name, multiple_values) stop_criterion, iterator, loss_function, monitors = get_normal_settings( ) trainer = NN_Trainer(iterator, loss_function, stop_criterion, monitors) if global_vars.get('model_name') == 'rnn': model = MultivariateLSTM(dataset['train'].X.shape[1], 100, global_vars.get('batch_size'), global_vars.get('input_height'), global_vars.get('n_classes'), eegnas=True) elif global_vars.get('model_name') == 'MHANet': model = MHANetModel(dataset['train'].X.shape[1],