コード例 #1
0
 def setUp(self):
     """Activate the single configuration of the 'tests' experiment.

     Parses CLI-style arguments selecting the ``tests`` experiment and the
     ``../configurations/config.ini`` file, initialises the configuration from
     that file, makes the (single) resulting configuration the active one in
     ``global_vars`` and finally calls ``set_params_by_dataset`` with
     ``../configurations/dataset_params.ini``.
     """
     args = parse_args(['-e', 'tests', '-c', '../configurations/config.ini'])
     init_config(args.config)
     # Look the experiment up by the name that was actually parsed from '-e'
     # rather than repeating the literal, so the flag and the lookup cannot
     # drift apart.
     configs = get_configurations(args.experiment)
     assert len(configs) == 1, (
         f"expected exactly one configuration for experiment "
         f"{args.experiment!r}, got {len(configs)}")
     global_vars.set_config(configs[0])
     set_params_by_dataset('../configurations/dataset_params.ini')
コード例 #2
0
 def test_get_multiple_values(self):
     """Check the 'test_multiple_values' experiment and its varying params.

     The experiment must yield exactly 12 configurations, and each of
     'cross_subject', 'num_conv_blocks' and 'num_generations' must appear in
     the result of ``get_multiple_values`` for those configurations. The first
     configuration is then made active in ``global_vars``.
     """
     args = parse_args([
         '-e', 'test_multiple_values', '-c', '../configurations/config.ini'
     ])
     init_config(args.config)
     # NOTE(review): the whole parsed-args object is passed here, while other
     # call sites pass an experiment name — confirm against the
     # get_configurations signature used by this version.
     configs = get_configurations(args)
     assert len(configs) == 12, (
         f"expected 12 configurations, got {len(configs)}")
     # Loop-invariant: compute the collection once instead of once per name.
     multiple_values = get_multiple_values(configs)
     for str_conf in ('cross_subject', 'num_conv_blocks', 'num_generations'):
         assert str_conf in multiple_values, (
             f"{str_conf!r} not found in {multiple_values!r}")
     global_vars.set_config(configs[0])
コード例 #3
0
ファイル: test_naiveNAS.py プロジェクト: erap129/EEGNAS
    def setUp(self):
        """Build a NaiveNAS instance over constant dummy data for the tests.

        Activates the single configuration of the 'tests' experiment, fixes
        eeg_chans/num_subjects/input_time_len/n_classes in ``global_vars``,
        calls ``set_params_by_dataset`` and then wires an iterator, the NLL
        loss, monitors and a stop criterion into a ``NaiveNAS`` object whose
        train/validation/test sets are all the same all-ones dummy set.
        """
        parsed = parse_args(
            ['-e', 'tests', '-c', '../configurations/config.ini'])
        init_config(parsed.config)
        experiment_configs = get_configurations(parsed.experiment)
        assert len(experiment_configs) == 1
        global_vars.set_config(experiment_configs[0])
        # Fixed dataset dimensions for the tests, applied in this order.
        for param_name, param_value in (('eeg_chans', 22),
                                        ('num_subjects', 9),
                                        ('input_time_len', 1125),
                                        ('n_classes', 4)):
            global_vars.set(param_name, param_value)
        set_params_by_dataset()
        n_examples = 50
        n_chans = global_vars.get('eeg_chans')
        n_times = global_vars.get('input_time_len')

        class Dummy:
            # Bare container exposing the X / y attributes a data set carries.
            def __init__(self, X, y):
                self.X = X
                self.y = y

        features = np.ones((n_examples, n_chans, n_times), dtype=np.float32)
        labels = np.ones(n_examples, dtype=np.longlong)
        dummy_data = Dummy(X=features, y=labels)

        self.iterator = BalancedBatchSizeIterator(
            batch_size=global_vars.get('batch_size'))
        self.loss_function = F.nll_loss
        self.monitors = [LossMonitor(), MisclassMonitor(),
                         GenericMonitor('accuracy', acc_func),
                         RuntimeMonitor()]
        # Stop when either rule fires.
        epoch_cap = MaxEpochs(global_vars.get('max_epochs'))
        plateau = NoDecrease('valid_misclass',
                             global_vars.get('max_increase_epochs'))
        self.stop_criterion = Or([epoch_cap, plateau])
        self.naiveNAS = NaiveNAS(
            iterator=self.iterator, exp_folder='../tests', exp_name='',
            train_set=dummy_data, val_set=dummy_data, test_set=dummy_data,
            stop_criterion=self.stop_criterion, monitors=self.monitors,
            loss_function=self.loss_function, config=global_vars.config,
            subject_id=1, fieldnames=None, model_from_file=None)
コード例 #4
0
from naiveNAS import NaiveNAS
from EEGNAS.models_generation import target_model
from braindecode.torch_ext.util import np_to_var
import numpy as np
from torch import nn
from braindecode.torch_ext.util import var_to_np
import torch as th
import matplotlib

# Force matplotlib's 'qt5agg' backend for any plotting done by this script.
matplotlib.use('qt5agg')

# Activate the single configuration of the 'tests' experiment from the shared
# config file.
# NOTE(review): parse_args, get_configurations, set_params_by_dataset and
# global_vars are not imported in the visible import block — confirm they are
# imported elsewhere in this script.
args = parse_args(['-e', 'tests', '-c', '../configurations/config.ini'])
global_vars.init_config(args.config)
configs = get_configurations(args.experiment)
assert (len(configs) == 1)
global_vars.set_config(configs[0])

# Run settings. Only `dataset` is pushed into global_vars here; the other
# values are presumably consumed further down the script (not visible) — confirm.
subject_id = 1
low_cut_hz = 0
fs = 250  # presumably the sampling rate in Hz — TODO confirm
valid_set_fraction = 0.2
dataset = 'BCI_IV_2a'
data_folder = '../data/'
global_vars.set('dataset', dataset)
# Presumably derives dataset-specific parameters from the 'dataset' value set
# just above — confirm against set_params_by_dataset.
set_params_by_dataset()
# Set the global 'cuda' flag to True.
global_vars.set('cuda', True)
# Presumably: model architecture selector, results sub-folder and checkpoint
# file name of the model to load — verify against the code that reads them.
model_select = 'deep4'
model_dir = '143_x_evolution_layers_cross_subject'
model_name = 'best_model_9_8_6_7_2_1_3_4_5.th'
# Empty containers, presumably filled further down the script.
train_set = {}
val_set = {}
コード例 #5
0
ファイル: EEGNAS_experiment.py プロジェクト: erap129/EEGNAS
    # NOTE(review): low_cut_hz and valid_set_fraction are not read in the
    # visible excerpt — presumably used further down; confirm.
    low_cut_hz = 0
    valid_set_fraction = 0.2
    # NOTE(review): listen() is defined elsewhere; its effect is not visible here.
    listen()
    # Base experiment id; may be replaced per configuration by the
    # global_vars 'exp_id' value inside the loop below.
    exp_id = get_exp_id('results')
    # Mapping from experiment-type name to its runner function (not used in
    # the visible excerpt).
    exp_funcs = {'per_subject': per_subject_exp,
                 'leave_one_out': leave_one_out_exp,
                 'cross_subject': cross_subject_exp}

    # args.experiment may name several experiments, separated by commas.
    experiments = args.experiment.split(',')
    all_exps = defaultdict(list)
    # NOTE(review): `first_run` is never read below; the loop reads and writes
    # `FIRST_RUN` instead. Unless FIRST_RUN is defined before this point (and
    # declared `global` if this code sits inside a function), `if FIRST_RUN:`
    # raises NameError / UnboundLocalError — confirm which name is intended.
    first_run = True
    for experiment in experiments:
        configurations = get_configurations(experiment, global_vars.configs)
        # Presumably the names of parameters that take several values across
        # these configurations — confirm against get_multiple_values.
        multiple_values = get_multiple_values(configurations)
        for index, configuration in enumerate(configurations):
            # Activated before the start-index check, so skipped
            # configurations are still set on global_vars briefly.
            global_vars.set_config(configuration)
            # start_exp_idx is compared with the 1-based position index+1;
            # configurations numbered below it are skipped.
            if index+1 < global_vars.get('start_exp_idx'):
                continue
            # A configured exp_id overrides the generated one. The
            # reassignment also persists for later configurations.
            if global_vars.get('exp_id'):
                exp_id = global_vars.get('exp_id')
            configuration['DEFAULT']['exp_id'] = exp_id
            # One-time step on the first configuration that is not skipped:
            # remember its dataset and extend multiple_values with
            # include_params_folder_name.
            # NOTE(review): multiple_values is recomputed for every experiment,
            # so this extension only affects the first experiment's names —
            # confirm that is intended.
            if FIRST_RUN:
                FIRST_DATASET = global_vars.get('dataset')
                if global_vars.get('include_params_folder_name'):
                    multiple_values.extend(global_vars.get('include_params_folder_name'))
                FIRST_RUN = False
            # '<exp_id>_<1-based index>_<experiment>', then passed through
            # add_params_to_name together with multiple_values.
            exp_name = f"{exp_id}_{index+1}_{experiment}"
            exp_name = add_params_to_name(exp_name, multiple_values)
            # Reset ex's config, then add this configuration plus a 'tags'
            # entry holding exp_id.
            ex.config = {}
            ex.add_config({**configuration, **{'tags': [exp_id]}})
            # The body of this `if` lies beyond the visible excerpt.
            if len(ex.observers) == 0 and not args.debug_mode:
コード例 #6
0
def set_default_config(path):
    """Initialise ``global_vars`` from *path* and activate 'default_exp'.

    The config file is loaded first. Then the configurations produced for the
    ``'default_exp'`` experiment are looked up, and the first of them becomes
    the active configuration.

    Args:
        path: configuration file location handed to ``global_vars.init_config``.

    Raises:
        Presumably ``IndexError`` (from the ``[0]`` lookup) when no
        'default_exp' configuration is produced — confirm the return type of
        ``get_configurations``.
    """
    global_vars.init_config(path)
    default_exp_configs = get_configurations('default_exp', global_vars.configs)
    active_config = default_exp_configs[0]
    global_vars.set_config(active_config)