示例#1
0
 def setUp(self):
     """Load the single 'tests' configuration and apply dataset parameters.

     Parses a fixed argument list, initialises the configuration from the
     parsed config path, installs the only configuration returned for the
     'tests' experiment into ``global_vars`` and finally applies the
     parameters from the dataset parameter file.
     """
     argv = ['-e', 'tests', '-c', '../configurations/config.ini']
     parsed = parse_args(argv)
     init_config(parsed.config)

     available = get_configurations('tests')
     # The fixture relies on exactly one configuration being defined.
     assert len(available) == 1
     global_vars.set_config(available[0])

     set_params_by_dataset('../configurations/dataset_params.ini')
 def test_get_multiple_values(self):
     """Check the settings reported as multi-valued for 'test_multiple_values'.

     The 'test_multiple_values' experiment is expected to produce 12
     configurations, and 'cross_subject', 'num_conv_blocks' and
     'num_generations' must each be contained in the result of
     ``get_multiple_values`` for those configurations.
     """
     args = parse_args([
         '-e', 'test_multiple_values', '-c', '../configurations/config.ini'
     ])
     init_config(args.config)
     # get_configurations is given the experiment name (a string), the same
     # way setUp calls it with 'tests'; passing the whole parsed namespace
     # was inconsistent with every other call site.
     configs = get_configurations(args.experiment)
     assert len(configs) == 12
     for str_conf in ('cross_subject', 'num_conv_blocks', 'num_generations'):
         assert str_conf in get_multiple_values(configs)
     # Leave the first configuration installed as the active one.
     global_vars.set_config(configs[0])
示例#3
0
    def setUp(self):
        """Build a ``NaiveNAS`` instance backed by constant dummy data.

        Loads the single 'tests' configuration, overrides the dataset
        dimensions in ``global_vars`` with fixed values, and creates a
        dummy dataset of ones (50 trials) that is used as the train,
        validation and test set. The iterator, loss function, monitors and
        stop criterion are stored on ``self`` and passed to ``NaiveNAS``.
        """
        cli_args = parse_args(
            ['-e', 'tests', '-c', '../configurations/config.ini'])
        init_config(cli_args.config)
        experiment_configs = get_configurations(cli_args.experiment)
        # The fixture relies on exactly one configuration being defined.
        assert len(experiment_configs) == 1
        global_vars.set_config(experiment_configs[0])

        # Fixed dataset dimensions, set before the dataset parameters.
        fixed_dimensions = (
            ('eeg_chans', 22),
            ('num_subjects', 9),
            ('input_time_len', 1125),
            ('n_classes', 4),
        )
        for key, value in fixed_dimensions:
            global_vars.set(key, value)
        set_params_by_dataset()

        class Dummy:
            # Minimal container exposing the X / y attributes of a dataset.
            def __init__(self, X, y):
                self.X = X
                self.y = y

        n_trials = 50
        n_channels = global_vars.get('eeg_chans')
        n_samples = global_vars.get('input_time_len')
        features = np.ones((n_trials, n_channels, n_samples),
                           dtype=np.float32)
        labels = np.ones(n_trials, dtype=np.longlong)
        dummy_data = Dummy(X=features, y=labels)

        batch_size = global_vars.get('batch_size')
        self.iterator = BalancedBatchSizeIterator(batch_size=batch_size)
        self.loss_function = F.nll_loss
        self.monitors = [
            LossMonitor(),
            MisclassMonitor(),
            GenericMonitor('accuracy', acc_func),
            RuntimeMonitor(),
        ]

        # Stop when either criterion is met.
        epoch_limit = MaxEpochs(global_vars.get('max_epochs'))
        no_improvement = NoDecrease('valid_misclass',
                                    global_vars.get('max_increase_epochs'))
        self.stop_criterion = Or([epoch_limit, no_improvement])

        nas_kwargs = dict(
            iterator=self.iterator,
            exp_folder='../tests',
            exp_name='',
            train_set=dummy_data,
            val_set=dummy_data,
            test_set=dummy_data,
            stop_criterion=self.stop_criterion,
            monitors=self.monitors,
            loss_function=self.loss_function,
            config=global_vars.config,
            subject_id=1,
            fieldnames=None,
            model_from_file=None,
        )
        self.naiveNAS = NaiveNAS(**nas_kwargs)
示例#4
0
from braindecode.torch_ext.util import set_random_seeds
from EEGNAS_experiment import get_normal_settings, parse_args, get_configurations
from naiveNAS import NaiveNAS
from EEGNAS.models_generation import target_model
from braindecode.torch_ext.util import np_to_var
import numpy as np
from torch import nn
from braindecode.torch_ext.util import var_to_np
import torch as th
import matplotlib

# Select the matplotlib backend by name.
matplotlib.use('qt5agg')

# Parse a fixed argument list for the 'tests' experiment, initialise the
# configuration from the parsed config path, and install the single
# configuration that experiment is expected to yield.
args = parse_args(['-e', 'tests', '-c', '../configurations/config.ini'])
global_vars.init_config(args.config)
configs = get_configurations(args.experiment)
assert (len(configs) == 1)
global_vars.set_config(configs[0])

# Script-level settings.
subject_id = 1
low_cut_hz = 0  # NOTE(review): presumably a filter cutoff in Hz - confirm
fs = 250  # NOTE(review): presumably the sampling rate in Hz - confirm
valid_set_fraction = 0.2
dataset = 'BCI_IV_2a'
data_folder = '../data/'
# Record the dataset name in the global settings before applying the
# dataset parameters.
global_vars.set('dataset', dataset)
set_params_by_dataset()
# NOTE(review): 'cuda' is unconditionally set to True here; presumably this
# requires a CUDA-capable device at runtime - confirm.
global_vars.set('cuda', True)
# Model selection and the location/name of a saved model file.
model_select = 'deep4'
model_dir = '143_x_evolution_layers_cross_subject'
model_name = 'best_model_9_8_6_7_2_1_3_4_5.th'