예제 #1
0
def reset_model(checkpoint):
    """Restore the module-level ``model`` from *checkpoint* and re-compile it.

    The network weights come from ``checkpoint['model_state_dict']``. A new
    AdamW optimizer is then created over the parameters whose
    ``requires_grad`` flag is set (frozen parameters are left out), and the
    model is compiled with NLL loss and a fixed iterator seed.
    """
    network = model.network
    network.load_state_dict(checkpoint['model_state_dict'])

    # Frozen parameters are excluded so the optimizer never touches them.
    trainable = (param for param in network.parameters() if param.requires_grad)
    optimizer = AdamW(trainable, lr=1 * 0.01, weight_decay=0.5 * 0.001)

    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=20200205)
예제 #2
0
def network_model(model, train_set, test_set, valid_set, n_chans, input_time_length, cuda):
    """Build a Deep4Net or ShallowFBCSPNet network and train it in an ``Experiment``.

    Parameters
    ----------
    model : str
        Architecture name: ``'deep'`` (Deep4Net) or ``'shallow'``
        (ShallowFBCSPNet).
    train_set, test_set, valid_set
        Datasets handed to ``Experiment``. Note that ``Experiment`` receives
        them in (train, valid, test) order. ``train_set.X.shape[0]`` is printed.
    n_chans : int
        Number of input channels.
    input_time_length : int
        Number of time samples per input window.
    cuda : bool
        When true, the network is moved to the GPU and CUDA is included in
        ``set_random_seeds``.

    Returns
    -------
    Experiment
        The experiment after ``run()`` has completed.

    Raises
    ------
    ValueError
        If *model* is neither ``'deep'`` nor ``'shallow'``. Previously an
        unknown name left ``model`` as a ``str``, which failed later with an
        unrelated AttributeError.
    """
    if model not in ('deep', 'shallow'):
        raise ValueError(
            "model must be 'deep' or 'shallow', got {!r}".format(model))

    max_epochs = 30
    max_increase_epochs = 10  # patience for the NoDecrease criterion below
    batch_size = 64
    n_classes = 2

    set_random_seeds(seed=20190629, cuda=cuda)

    if model == 'deep':
        network = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
                           final_conv_length='auto').create_network()
    else:
        network = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
                                  final_conv_length='auto').create_network()

    if cuda:
        network.cuda()

    # The original called "%s model: ".format(...). That template has no "{}"
    # field, so the model description never reached the log.
    log.info("%s model: ", network)

    optimizer = AdamW(network.parameters(), lr=0.00625, weight_decay=0)
    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    # Stop at max_epochs, or earlier via NoDecrease on 'valid_misclass'.
    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])
    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    print(train_set.X.shape[0])

    model_test = Experiment(network, train_set, valid_set, test_set, iterator=iterator,
                            loss_function=F.nll_loss, optimizer=optimizer,
                            model_constraint=None, monitors=monitors,
                            stop_criterion=stop_criterion, remember_best_column='valid_misclass',
                            run_after_early_stop=True, cuda=cuda)

    model_test.run()
    return model_test
예제 #3
0
파일: deep4.py 프로젝트: kahartma/eeggan
def train(train_set, test_set, model, iterator, monitors, loss_function,
          max_epochs, cuda):
    """Train *model* for exactly ``max_epochs`` epochs with cosine-annealed AdamW.

    No validation set is used and early stopping is disabled
    (``do_early_stop=False``, ``run_after_early_stop=False``). A
    ``MaxNormDefaultConstraint`` is installed as the model constraint.

    Returns the ``Experiment`` after ``run()`` has completed.
    """
    if cuda:
        model.cuda()

    # these are good values for the deep model
    base_optimizer = AdamW(model.parameters(), lr=1 * 0.01,
                           weight_decay=0.5 * 0.001)

    stop_criterion = MaxEpochs(max_epochs)
    model_constraint = MaxNormDefaultConstraint()

    # The cosine schedule spans the whole run, so count the batches one pass
    # over the training set yields and multiply by the epoch budget.
    n_updates_per_epoch = sum(
        1 for _ in iterator.get_batches(train_set, shuffle=True))
    optimizer = ScheduledOptimizer(
        CosineAnnealing(n_updates_per_epoch * max_epochs),
        base_optimizer,
        schedule_weight_decay=True,
    )

    experiment = Experiment(
        model,
        train_set,
        None,
        test_set,
        iterator=iterator,
        loss_function=loss_function,
        optimizer=optimizer,
        model_constraint=model_constraint,
        monitors=monitors,
        remember_best_column=None,
        stop_criterion=stop_criterion,
        cuda=cuda,
        run_after_early_stop=False,
        do_early_stop=False,
    )
    experiment.run()
    return experiment
예제 #4
0
# Counterpart of an `if` branch that lies outside this excerpt (presumably
# the `train_type == 'trialwise'` case tested further below; confirm).
else: # cropped
    # Cropped models are built with input_time_length=None and
    # final_conv_length=1. NOTE(review): presumably so the network accepts
    # variable-length input and emits one prediction per crop position;
    # confirm against the braindecode version in use.
    if model_type == 'shallow':
        model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                            input_time_length=None,
                            final_conv_length=1)
    else:
        # Every model_type other than 'shallow' falls through to Deep4Net.
        model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                            input_time_length=None,
                            final_conv_length=1)
# Move the freshly built model to the GPU when requested.
if cuda:
    model.cuda()

from braindecode.torch_ext.optimizers import AdamW
import torch.nn.functional as F

# Per-architecture AdamW settings: the shallow net uses a small lr and no
# weight decay; the other branch uses the values that are good for the deep model.
optimizer = AdamW(
    model.parameters(),
    lr=0.0625 * 0.01 if model_type == 'shallow' else 1*0.01,
    weight_decay=0 if model_type == 'shallow' else 0.5*0.001,
)

# Compile model exactly the same way as when you trained it: cropped
# training (any train_type other than 'trialwise') additionally passes
# cropped=True.
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1,
              **({} if train_type == 'trialwise' else {'cropped': True}))

print("INFO : Epochs: {}".format(epoches))
print("INFO : Batch Size: {}".format(batch_size))


# Fit model exactly the same way as when you trained it (omit any optional params though)
예제 #5
0
                            pool_mode='mean',
                            split_first_layer=True,
                            batch_norm=True,
                            batch_norm_alpha=0.1,
                            drop_prob=0.5)

    # Move the model built above to the GPU when requested.
    if cuda:
        model.cuda()

    ###########################################################################
    ### (5) Create optimizer and compile the cropped model ####################
    ###########################################################################
    from braindecode.torch_ext.optimizers import AdamW
    import torch.nn.functional as F
    # Alternative settings for the deep model (disabled):
    #optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
    # Shallow-model settings: lr = 0.0625 * 0.01, no weight decay.
    optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
    # cropped=True compiles the model for cropped (sliding-window) training.
    model.compile(loss=F.nll_loss,
                  optimizer=optimizer,
                  iterator_seed=1,
                  cropped=True)

    ###########################################################################
    ### (6) Run the training ##################################################
    ###########################################################################
    #input_time_length = 1*Fs
    # NOTE(review): input_time_length is not used in the visible part of the
    # fit call below; presumably it is passed further down (the call
    # continues past this excerpt). Confirm.
    input_time_length = 200
    model.fit(
        train_set.X,
        train_set.y,
        epochs=250,
        batch_size=64,
예제 #6
0
def network_model(subject_id, model_type, data_type, cropped, cuda, parameters, hyp_params):
    """Nested cross-validated hyper-parameter search and evaluation for one subject.

    The data is split into 4 stratified outer folds. For each outer fold, an
    inner 4-fold split of the remaining ("inner") data is used to train one
    model per (dropout, loss, learning-rate) combination in *hyp_params*. Each
    candidate is scored with ``current_loss`` applied to its ``valid_loss``
    column. The combination with the lowest mean inner-fold score is reported.
    A final model is then trained on the full inner set, using the selected
    loss function, and evaluated on the outer fold, which is split 50/50 into
    validation and test sets.

    Parameters
    ----------
    subject_id
        Forwarded to ``format_data``; also used in the result file paths.
    model_type : str
        ``'shallow'`` (ShallowFBCSPNet + Adadelta), ``'deep'`` (Deep4Net +
        Adadelta during the search) or ``'eegnet'`` (EEGNetv4 + Adam).
    data_type : str
        ``'words'``, ``'vowels'`` or ``'all_classes'``; selects the
        within-trial sample window.
    cropped, cuda
        Forwarded to the experiment classes; *cuda* also moves models to the GPU.
    parameters : dict
        Must provide ``'max_epochs'``, ``'max_increase_epochs'`` and ``'batch_size'``.
    hyp_params : dict
        Search grid. Requires ``'drop_prob'`` and ``'loss'`` (iterables), plus
        ``'lr_ada'`` (Adadelta rates, shallow/deep) or ``'lr_adam'`` (Adam
        rates, eegnet).

    Returns
    -------
    tuple
        ``(optimized_model, mean accuracy over the outer folds)``.

    Raises
    ------
    ValueError
        For an unknown *data_type* or *model_type*. Previously these surfaced
        later as UnboundLocalError / AttributeError.
    """
    # Within-trial sample window [start, stop) used for classification.
    windows = {
        'words': (768, 1280),
        'vowels': (512, 1024),
        'all_classes': (768, 1280),
    }
    if data_type not in windows:
        raise ValueError(f"unknown data_type {data_type!r}")
    if model_type not in ('shallow', 'deep', 'eegnet'):
        raise ValueError(f"unknown model_type {model_type!r}")

    #####Parameters passed to function#####
    max_epochs = parameters['max_epochs']
    max_increase_epochs = parameters['max_increase_epochs']
    batch_size = parameters['batch_size']

    #####Constant Parameters#####
    iterator = BalancedBatchSizeIterator(batch_size=batch_size)
    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])
    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]
    model_constraint = MaxNormDefaultConstraint()
    epoch = 4096

    #####Collect and format data#####
    data, labels = format_data(data_type, subject_id, epoch)
    start, stop = windows[data_type]
    data = data[:, :, start:stop]

    data = data * 1e6  # improves numerical stability
    data = normalize(data)
    data, labels = balanced_subsample(data, labels)  # downsample so classes are balanced

    # Redundant shuffle of data/labels. This is the same RandomState(42)
    # permutation that train_test_split(..., test_size=0, random_state=42)
    # produced on older scikit-learn versions. Current versions reject
    # test_size=0 with a ValueError.
    perm = np.random.RandomState(42).permutation(len(labels))
    data, labels = data[perm], labels[perm]

    #####model inputs#####
    n_classes = len(np.unique(labels))
    n_chans = int(data.shape[1])
    input_time_length = data.shape[2]

    # Learning rates to try, matching the optimizer used in the search:
    # Adam (eegnet) uses 'lr_adam'; Adadelta (shallow/deep) uses 'lr_ada'.
    # Previously both lists were indexed by range(len(hyp_params['lr_adam'])),
    # which raised IndexError or skipped rates when their lengths differed.
    lr_grid = hyp_params['lr_adam'] if model_type == 'eegnet' else hyp_params['lr_ada']

    #####k-fold nested cross-validation#####
    num_folds = 4
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=10)

    cv_scores = []  # one accuracy per outer fold
    optimized_model = None
    #####Outer-Fold#####
    for out_fold_num, (inner_ind, outer_index) in enumerate(skf.split(data, labels), start=1):
        inner_fold, outer_fold = data[inner_ind], data[outer_index]
        inner_labels, outer_labels = labels[inner_ind], labels[outer_index]

        loss_with_params = dict()  # {"Fold_k": {param label: validation score}}
        params_by_label = dict()  # param label -> (drop_prob, loss_function, lr)

        #####Inner-Fold#####
        for in_fold_num, (train_idx, valid_idx) in enumerate(
                skf.split(inner_fold, inner_labels), start=1):
            train_set = SignalAndTarget(inner_fold[train_idx], inner_labels[train_idx])
            valid_set = SignalAndTarget(inner_fold[valid_idx], inner_labels[valid_idx])
            fold_losses = loss_with_params[f"Fold_{in_fold_num}"] = dict()

            ####Hyper-parameter grid#####
            for drop_prob in hyp_params['drop_prob']:
                for loss_function in hyp_params['loss']:
                    for lr in lr_grid:
                        # Seed BEFORE building the network so every candidate
                        # starts from a reproducible initialisation (the seed
                        # used to be set only after construction).
                        set_random_seeds(seed=20190629, cuda=cuda)

                        # model, learning-rate and optimizer setup according to model_type
                        if model_type == 'shallow':
                            model = ShallowFBCSPNet(in_chans=n_chans, n_classes=n_classes, input_time_length=input_time_length,
                                                    n_filters_time=80, filter_time_length=40, n_filters_spat=80,
                                                    pool_time_length=75, pool_time_stride=25, final_conv_length='auto',
                                                    conv_nonlin=square, pool_mode='max', pool_nonlin=safe_log,
                                                    split_first_layer=True, batch_norm=True, batch_norm_alpha=0.1,
                                                    drop_prob=drop_prob).create_network()
                            optimizer = optim.Adadelta(model.parameters(), lr=lr, rho=0.9, weight_decay=0.1, eps=1e-8)
                        elif model_type == 'deep':
                            model = Deep4Net(in_chans=n_chans, n_classes=n_classes, input_time_length=input_time_length,
                                             final_conv_length='auto', n_filters_time=20, n_filters_spat=20, filter_time_length=10,
                                             pool_time_length=3, pool_time_stride=3, n_filters_2=50, filter_length_2=15,
                                             n_filters_3=100, filter_length_3=15, n_filters_4=400, filter_length_4=10,
                                             first_nonlin=leaky_relu, first_pool_mode='max', first_pool_nonlin=safe_log, later_nonlin=leaky_relu,
                                             later_pool_mode='max', later_pool_nonlin=safe_log, drop_prob=drop_prob,
                                             double_time_convs=False, split_first_layer=False, batch_norm=True, batch_norm_alpha=0.1,
                                             stride_before_pool=False).create_network()  # filter_length_4 changed from 15 to 10
                            optimizer = optim.Adadelta(model.parameters(), lr=lr, weight_decay=0.1, eps=1e-8)
                        else:  # 'eegnet' (validated above)
                            model = EEGNetv4(in_chans=n_chans, n_classes=n_classes, final_conv_length='auto',
                                             input_time_length=input_time_length, pool_mode='mean', F1=16, D=2, F2=32,
                                             kernel_length=64, third_kernel_size=(8, 4), drop_prob=drop_prob).create_network()
                            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0, eps=1e-8, amsgrad=False)

                        if cuda:
                            model.cuda()
                            torch.backends.cudnn.deterministic = True
                        model = torch.nn.DataParallel(model)
                        # Fixed: "%s".format(...) never interpolated the model.
                        log.info("%s model: ", model)

                        #####Setup to run the selected model#####
                        model_test = Experiment(model, train_set, valid_set, test_set=None, iterator=iterator,
                                                loss_function=loss_function, optimizer=optimizer,
                                                model_constraint=model_constraint, monitors=monitors,
                                                stop_criterion=stop_criterion, remember_best_column='valid_misclass',
                                                run_after_early_stop=True, model_loss_function=None, cuda=cuda,
                                                data_type=data_type, subject_id=subject_id, model_type=model_type,
                                                cropped=cropped, model_number=str(out_fold_num))
                        model_test.run()

                        model_loss = model_test.epochs_df['valid_loss'].astype('float')
                        label = f"{drop_prob}/{loss_function}/{lr}"
                        fold_losses[label] = current_loss(model_loss)
                        params_by_label[label] = (drop_prob, loss_function, lr)

        ####Select and train optimized model#####
        df = pd.DataFrame(loss_with_params)
        df['mean'] = df.mean(axis=1)  # mean score across the inner folds
        # NOTE: backslash separators, so these paths only resolve as
        # directories on Windows.
        writer_df = f"results_folder\\results\\S{subject_id}\\{model_type}_parameters.xlsx"
        df.to_excel(writer_df)

        # Map the winning row label back to the real parameter objects. The
        # old code parsed the label string via the private Series._name and
        # str(loss_function)[10:13], which only recognised nll_loss and
        # cross_entropy.
        best_dp, best_loss, best_lr = params_by_label[df['mean'].idxmin()]
        best_loss_name = getattr(best_loss, '__name__', best_loss)
        print(f"Best parameters: dropout: {best_dp}, loss: {best_loss_name}, lr: {best_lr}")
        # NOTE(review): best_dp and best_lr are only reported. The final models
        # below use hard-coded dropout/learning rates, and for shallow/deep
        # different layer settings than the search. Confirm this is intended.

        #####Train model on entire inner fold set#####
        torch.backends.cudnn.deterministic = True
        #####Create outer-fold validation and test sets#####
        X_valid, X_test, y_valid, y_test = train_test_split(outer_fold, outer_labels, test_size=0.5, random_state=42, stratify=outer_labels)
        train_set = SignalAndTarget(inner_fold, inner_labels)
        valid_set = SignalAndTarget(X_valid, y_valid)
        test_set = SignalAndTarget(X_test, y_test)

        if model_type == 'shallow':
            model = ShallowFBCSPNet(in_chans=n_chans, n_classes=n_classes, input_time_length=input_time_length,
                                    n_filters_time=60, filter_time_length=5, n_filters_spat=40,
                                    pool_time_length=50, pool_time_stride=15, final_conv_length='auto',
                                    conv_nonlin=relu6, pool_mode='mean', pool_nonlin=safe_log,
                                    split_first_layer=True, batch_norm=True, batch_norm_alpha=0.1,
                                    drop_prob=0.1).create_network()  # 50 works better than 75
            optimizer = optim.Adadelta(model.parameters(), lr=2.0, rho=0.9, weight_decay=0.1, eps=1e-8)
        elif model_type == 'deep':
            model = Deep4Net(in_chans=n_chans, n_classes=n_classes, input_time_length=input_time_length,
                             final_conv_length='auto', n_filters_time=20, n_filters_spat=20, filter_time_length=5,
                             pool_time_length=3, pool_time_stride=3, n_filters_2=20, filter_length_2=5,
                             n_filters_3=40, filter_length_3=5, n_filters_4=1500, filter_length_4=10,
                             first_nonlin=leaky_relu, first_pool_mode='mean', first_pool_nonlin=safe_log, later_nonlin=leaky_relu,
                             later_pool_mode='mean', later_pool_nonlin=safe_log, drop_prob=0.1,
                             double_time_convs=False, split_first_layer=True, batch_norm=True, batch_norm_alpha=0.1,
                             stride_before_pool=False).create_network()
            optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=0)
        else:  # 'eegnet'
            model = EEGNetv4(in_chans=n_chans, n_classes=n_classes, final_conv_length='auto',
                             input_time_length=input_time_length, pool_mode='mean', F1=16, D=2, F2=32,
                             kernel_length=64, third_kernel_size=(8, 4), drop_prob=0.1).create_network()
            optimizer = optim.Adam(model.parameters(), lr=0.1, weight_decay=0, eps=1e-8, amsgrad=False)

        if cuda:
            model.cuda()
            torch.backends.cudnn.deterministic = True
            # model = torch.nn.DataParallel(model)

        log.info("Optimized model")

        #####Setup to run the optimized model#####
        optimized_model = op_exp(model, train_set, valid_set, test_set=test_set, iterator=iterator,
                                 loss_function=best_loss, optimizer=optimizer,
                                 model_constraint=model_constraint, monitors=monitors,
                                 stop_criterion=stop_criterion, remember_best_column='valid_misclass',
                                 run_after_early_stop=True, model_loss_function=None, cuda=cuda,
                                 data_type=data_type, subject_id=subject_id, model_type=model_type,
                                 cropped=cropped, model_number=str(out_fold_num))
        optimized_model.run()

        log.info("Last 5 epochs")
        log.info("\n" + str(optimized_model.epochs_df.iloc[-5:]))

        writer = f"results_folder\\results\\S{subject_id}\\{data_type}_{model_type}_{str(out_fold_num)}.xlsx"
        optimized_model.epochs_df.iloc[-30:].to_excel(writer)

        accuracy = 1 - np.min(np.array(optimized_model.class_acc))
        cv_scores.append(accuracy)  # one score per outer fold

    #####Print and store fold accuracies and mean accuracy#####
    cv_mean = np.mean(np.array(cv_scores))
    print(f"Class Accuracy: {cv_mean}")
    results_df = pd.DataFrame(dict(cv_scores=cv_scores, cv_mean=cv_mean))

    writer2 = f"results_folder\\results\\S{subject_id}\\{data_type}_{model_type}_cvscores.xlsx"
    results_df.to_excel(writer2)
    return optimized_model, cv_mean
예제 #7
0
            # Optional trial selection, currently disabled: keep only the
            # trials labelled 0 or 1 (left/right hand imagery).
            # indices = np.where((y == 0) | (y == 1))
            # X = np.take(X, indices[0], axis=0)
            # y = np.take(y, indices[0])

            # del FeatVect, y_labels, labels, raw_data

            # Wrap all of X / y as the training set. No train/valid/test split
            # is made here; the split helpers below are commented out.
            train_set = SignalAndTarget(X, y=y)
            # train_set, valid_set, test_set = split_into_train_valid_test(train_set, 3, 2, rng=RandomState((2019,22,3)))
            # train_set, test_set = split_into_two_sets(train_set, first_set_fraction=0.8)

            print('data prepared')

            # AdamW with lr=0.01 and weight_decay=0.1 * 0.001 (= 1e-4).
            optimizer = AdamW(model.parameters(),
                              lr=0.01,
                              weight_decay=0.1 * 0.001)
            print('optimizer created')

            # Disabled: reload previously saved weights once a model exists.
            #if model_created:
            #    model.network.load_state_dict(th.load(save_path))
            #    model.network.train()
            # NOTE(review): within this excerpt model_created is only read by
            # the commented-out block above; confirm it is used elsewhere.
            model_created = True

            # Compile for cropped training with NLL loss and a fixed iterator seed.
            model.compile(loss=F.nll_loss,
                          optimizer=optimizer,
                          iterator_seed=1,
                          cropped=True)
            print('compiled')

            model.fit(train_set.X,
예제 #8
0
def get_optmizer(model, config):
    """Create an AdamW optimizer over the trainable parameters of *model*.

    ``config['optimizer']`` supplies the ``lr`` and ``weight_decay`` values.
    Parameters whose ``requires_grad`` is false (frozen) are not handed to
    the optimizer.
    """
    optimizer_cfg = config['optimizer']
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    return AdamW(trainable_params,
                 lr=optimizer_cfg['lr'],
                 weight_decay=optimizer_cfg['weight_decay'])
예제 #9
0
def run_exp(
    debug,
    subject_id,
    max_epochs,
    n_sensors,
    final_hz,
    half_before,
    start_ms,
    stop_ms,
    model,
    weight_decay,
    final_fft,
    add_bnorm,
    act_norm,
):
    """Train one network on one subject's train/test data and return its history.

    Parameters
    ----------
    debug : bool
        Passed to ``load_train_test`` as ``only_load_given_sensors``. The
        ``car`` argument passed there is ``not debug``.
    subject_id, n_sensors, half_before, start_ms, stop_ms
        Forwarded to ``load_train_test``.
    final_hz : int
        Forwarded to ``load_train_test``; must be 64 or 256.
    max_epochs : int
        Number of epochs passed to ``model.fit``.
    model : str
        Architecture name: ``'shallow'``, ``'deep'``, ``'invertible'`` or
        ``'deep_invertible'``.
    weight_decay
        Must be the string ``'hardcoded'`` for ``'shallow'``/``'deep'``, which
        use fixed AdamW settings. For the invertible models it is the AdamW
        weight decay.
    final_fft, add_bnorm
        Forwarded to ``InvertibleModel`` (``'invertible'`` only).
    act_norm : bool
        For ``'deep_invertible'`` only: run one forward pass on the first
        training class with ``scale_to_unit_var`` hooks attached to every
        module that has a ``log_factor`` attribute.

    Returns
    -------
    tuple
        ``(model.epochs_df, model.network)`` after training. The test split
        is used as ``validation_data``.

    Raises
    ------
    ValueError
        For an unsupported ``final_hz`` or model name, or a ``weight_decay``
        other than ``'hardcoded'`` with ``'shallow'``/``'deep'``. These were
        ``assert`` statements, which ``python -O`` strips. They are now
        checked before any data is loaded.
    """
    model_name = model
    if final_hz not in (64, 256):
        raise ValueError("final_hz must be 64 or 256, got {!r}".format(final_hz))
    if model_name not in ('shallow', 'deep', 'invertible', 'deep_invertible'):
        raise ValueError("unknown model {!r}".format(model_name))
    if model_name in ('shallow', 'deep') and weight_decay != 'hardcoded':
        raise ValueError(
            "weight_decay must be 'hardcoded' for model {!r}".format(model_name))

    car = not debug
    train_inputs, test_inputs = load_train_test(
        subject_id,
        car,
        n_sensors,
        final_hz,
        start_ms,
        stop_ms,
        half_before,
        only_load_given_sensors=debug,
    )

    cuda = True  # NOTE: hard-coded; this function always runs on the GPU
    if cuda:
        train_inputs = [i.cuda() for i in train_inputs]
        test_inputs = [i.cuda() for i in test_inputs]

    from braindecode.datautil.signal_target import SignalAndTarget

    # One SignalAndTarget per split; the samples in inputs[i] get label i.
    # (The loop previously bound the result to `set`, shadowing the builtin.)
    sets = []
    for inputs in (train_inputs, test_inputs):
        X = np.concatenate([var_to_np(ins) for ins in inputs]).astype(np.float32)
        y = np.concatenate(
            [np.ones(len(ins)) * i_class for i_class, ins in enumerate(inputs)]
        ).astype(np.int64)
        sets.append(SignalAndTarget(X, y))
    train_set, valid_set = sets

    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from braindecode.models.deep4 import Deep4Net
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds

    set_random_seeds(2019011641, cuda)
    n_chans = train_inputs[0].shape[1]
    n_time = train_inputs[0].shape[2]
    n_classes = 2
    input_time_length = train_set.X.shape[2]

    if model_name == 'shallow':
        # final_conv_length = auto ensures we only get a single output in the time dimension
        model = ShallowFBCSPNet(in_chans=n_chans, n_classes=n_classes,
                                input_time_length=input_time_length,
                                final_conv_length='auto')
    elif model_name == 'deep':
        model = Deep4Net(n_chans, n_classes,
                         input_time_length=input_time_length,
                         pool_time_length=2,
                         pool_time_stride=2,
                         final_conv_length='auto')
    elif model_name == 'invertible':
        model = InvertibleModel(n_chans, n_time, final_fft=final_fft,
                                add_bnorm=add_bnorm)
    else:  # 'deep_invertible' (validated above)
        n_chan_pad = 0
        filter_length_time = 11
        model = deep_invertible(
            n_chans, input_time_length, n_chan_pad, filter_length_time)
        # Keep the first two channels of time step 0 as class scores.
        model.add_module("select_dims", Expression(lambda x: x[:, :2, 0]))
        model.add_module("softmax", nn.LogSoftmax(dim=1))
        model = WrappedModel(model)

        ## set scale
        if act_norm:
            model.cuda()
            # Attach scale_to_unit_var to every module with a log_factor, run
            # one forward pass on the first class's training inputs, then
            # remove the hooks again.
            for module in model.network.modules():
                if hasattr(module, 'log_factor'):
                    module._forward_hooks.clear()
                    module.register_forward_hook(scale_to_unit_var)
            model.network(train_inputs[0].cuda())
            for module in model.network.modules():
                if hasattr(module, 'log_factor'):
                    module._forward_hooks.clear()

    if cuda:
        model.cuda()

    from braindecode.torch_ext.optimizers import AdamW
    import torch.nn.functional as F

    # (lr, weight_decay) per architecture. shallow/deep use fixed values
    # (the deep ones "are good values for the deep model"); the invertible
    # variants take the caller-supplied weight_decay.
    optimizer_settings = {
        'shallow': (0.0625 * 0.01, 0),
        'deep': (1 * 0.01, 0.5 * 0.001),
        'invertible': (1e-4, weight_decay),
        'deep_invertible': (1 * 0.001, weight_decay),
    }
    lr, wd = optimizer_settings[model_name]
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=wd)

    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
    model.fit(train_set.X, train_set.y, epochs=max_epochs, batch_size=64,
              scheduler='cosine',
              validation_data=(valid_set.X, valid_set.y))

    return model.epochs_df, model.network
def setup_exp(
        train_folder,
        n_recordings,
        n_chans,
        model_name,
        n_start_chans,
        n_chan_factor,
        input_time_length,
        final_conv_length,
        model_constraint,
        stride_before_pool,
        init_lr,
        batch_size,
        max_epochs,
        cuda,
        num_workers,
        task,
        weight_decay,
        n_folds,
        shuffle_folds,
        lazy_loading,
        eval_folder,
        result_folder,
        run_on_normals,
        run_on_abnormals,
        seed,
        l2_decay,
        gradient_clip,
        ):
    """Build (but do not run) a cropped-decoding ``Experiment`` on TUH data.

    Validation mode (``eval_folder is None``): the recordings in
    *train_folder* are split with ``n_folds``-fold ``KFold``, and fold number
    ``seed`` is the test set. Final-evaluation mode: *train_folder* is the
    training set and *eval_folder* the test set, and the known broken
    recording of subject 00008184 is removed from both.

    Training always runs for ``max_epochs`` epochs (no early stopping) with
    AdamW, or ``ExtendedAdam`` for ``'tcn'``, under a cosine-annealing
    schedule over all updates.

    Returns
    -------
    Experiment
        Configured experiment; the caller is expected to run it.

    Raises
    ------
    ValueError
        If ``model_name`` is unknown, if ``model_constraint`` is neither
        ``None`` nor ``'defaultnorm'``, and, in validation mode, if the fold
        index ``seed`` is outside ``[0, n_folds)`` or if both
        ``run_on_normals`` and ``run_on_abnormals`` are set.
    """
    # Fail fast on bad configuration, before touching the GPU or the data.
    if model_name not in ('shallow', 'deep', 'eegnet', 'tcn'):
        raise ValueError("unknown model name {!r}".format(model_name))
    if model_constraint is not None and model_constraint != 'defaultnorm':
        raise ValueError(
            "model_constraint must be None or 'defaultnorm', got {!r}".format(
                model_constraint))
    if eval_folder is None:
        # The fold index used to be checked nowhere: an out-of-range seed
        # silently selected the last fold.
        if not 0 <= seed < n_folds:
            raise ValueError(
                "seed selects the fold and must be in [0, {}), got {!r}".format(
                    n_folds, seed))
        if run_on_normals and run_on_abnormals:
            raise ValueError(
                "decide whether to run on normal or abnormal subjects")

    # .get() so the function also works outside a SLURM job.
    logging.info("using {}, {}, gpu {}".format(
        os.environ.get("SLURM_JOB_PARTITION"),
        os.environ.get("SLURMD_NODENAME"),
        os.environ.get("CUDA_VISIBLE_DEVICES")))
    logging.info("Targets for this task: <{}>".format(task))

    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True

    loss_function = nll_loss_on_mean
    remember_best_column = "valid_misclass"
    n_classes = 2

    if model_constraint is not None:
        model_constraint = MaxNormDefaultConstraint()

    stop_criterion = MaxEpochs(max_epochs)

    set_random_seeds(seed=seed, cuda=cuda)
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=n_chans, n_classes=n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(
            n_chans, n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            n_filters_2=int(n_start_chans * n_chan_factor),
            n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
            n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
            final_conv_length=final_conv_length,
            stride_before_pool=stride_before_pool).create_network()
    elif model_name == 'eegnet':
        model = EEGNetv4(
            n_chans, n_classes,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    else:  # "tcn" (validated above)
        model = TemporalConvNet(
            input_size=n_chans,
            output_size=n_classes,
            context_size=0,
            num_channels=55,
            num_levels=5,
            kernel_size=16,
            dropout=0.05270154233150525,
            skip_mode=None,
            use_context=0,
            lasso_selection=0.0,
            rnn_normalization=None)

    # maybe check if this works and wait / re-try after some time?
    # in case of all cuda devices are busy
    if cuda:
        model.cuda()

    if model_name != "tcn":
        to_dense_prediction_model(model)
    logging.info("Model:\n{:s}".format(str(model)))

    # Dummy forward pass to find how many predictions one input window yields.
    test_input = np_to_var(np.ones((2, n_chans, input_time_length, 1),
                                   dtype=np.float32))
    if list(model.parameters())[0].is_cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]

    if eval_folder is None:
        logging.info("will do validation")
        if lazy_loading:
            logging.info("using lazy loading to load {} recs"
                         .format(n_recordings))
            dataset = TuhLazy(train_folder, target=task,
                              n_recordings=n_recordings)
        else:
            logging.info("using traditional loading to load {} recs"
                         .format(n_recordings))
            dataset = Tuh(train_folder, n_recordings=n_recordings, target=task)

        # only run on normal subjects
        if run_on_normals:
            ids = [i for i in range(len(dataset))
                   if dataset.pathologicals[i] == 0]  # 0 is non-pathological
            dataset = TuhSubset(dataset, ids)
            logging.info("only using {} normal subjects".format(len(dataset)))
        if run_on_abnormals:
            ids = [i for i in range(len(dataset))
                   if dataset.pathologicals[i] == 1]  # 1 is pathological
            dataset = TuhSubset(dataset, ids)
            logging.info("only using {} abnormal subjects".format(len(dataset)))

        # Each job evaluates fold number `seed`, so every job must see the
        # same partition. With shuffling and random_state=None, KFold drew
        # from NumPy's global RNG, which (presumably via set_random_seeds
        # above) differed per job, so test folds of different jobs could
        # overlap. A fixed random_state makes the partition job-independent.
        # scikit-learn rejects a random_state when shuffle is False.
        fold_shuffle_seed = 0
        indices = np.arange(len(dataset))
        kf = KFold(n_splits=n_folds, shuffle=shuffle_folds,
                   random_state=fold_shuffle_seed if shuffle_folds else None)
        # seed is in range of number of folds and was set by submit script
        train_ind, test_ind = list(kf.split(indices))[seed]
        assert len(np.intersect1d(train_ind, test_ind)) == 0, (
            "train and test set overlap!")

        if lazy_loading:
            test_subset = TuhLazySubset(dataset, test_ind)
            train_subset = TuhLazySubset(dataset, train_ind)
        else:
            test_subset = TuhSubset(dataset, test_ind)
            train_subset = TuhSubset(dataset, train_ind)
    else:
        logging.info("will do final evaluation")
        if lazy_loading:
            train_subset = TuhLazy(train_folder, target=task)
            test_subset = TuhLazy(eval_folder, target=task)
        else:
            train_subset = Tuh(train_folder, target=task)
            test_subset = Tuh(eval_folder, target=task)

        # remove rec:
        # train/abnormal/01_tcp_ar/081/00008184/s001_2011_09_21/00008184_s001_t001
        # since it contains no crop without outliers (channels A1, A2 broken)
        broken_subject = "00008184"
        broken_file = ("train/abnormal/01_tcp_ar/081/00008184/s001_2011_09_21/"
                       "00008184_s001_t001")

        def drop_broken_recording(subset):
            # Return subset without the broken subject's recording, if present.
            subjects = [f.split("/")[-3] for f in subset.file_paths]
            if broken_subject not in subjects:
                return subset
            return remove_file_from_dataset(
                subset, file_id=subjects.index(broken_subject),
                file=broken_file)

        train_subset = drop_broken_recording(train_subset)
        test_subset = drop_broken_recording(test_subset)

    if lazy_loading:
        iterator = LazyCropsFromTrialsIterator(
            input_time_length, n_preds_per_input, batch_size,
            seed=seed, num_workers=num_workers,
            reset_rng_after_each_batch=False,
            check_preds_smaller_trial_len=False)  # True!
    else:
        iterator = CropsFromTrialsIterator(batch_size, input_time_length,
                                           n_preds_per_input, seed)

    monitors = [
        LossMonitor(),
        RAMMonitor(),
        RuntimeMonitor(),
        CroppedDiagnosisMonitor(input_time_length, n_preds_per_input),
        LazyMisclassMonitor(col_suffix='sample_misclass'),
    ]

    # The cosine schedule spans every update of the whole run.
    if lazy_loading:
        n_updates_per_epoch = len(iterator.get_batches(train_subset,
                                                       shuffle=False))
    else:
        n_updates_per_epoch = sum(1 for _ in iterator.get_batches(
            train_subset, shuffle=False))
    n_updates_per_period = n_updates_per_epoch * max_epochs
    logging.info("there are {} updates per epoch".format(n_updates_per_epoch))

    if model_name == "tcn":
        adamw = ExtendedAdam(model.parameters(), lr=init_lr,
                             weight_decay=weight_decay, l2_decay=l2_decay,
                             gradient_clip=gradient_clip)
    else:
        adamw = AdamW(model.parameters(), init_lr, weight_decay=weight_decay)

    scheduler = CosineAnnealing(n_updates_per_period)
    optimizer = ScheduledOptimizer(scheduler, adamw, schedule_weight_decay=True)

    exp = Experiment(
        model=model,
        train_set=train_subset,
        valid_set=None,
        test_set=test_subset,
        iterator=iterator,
        loss_function=loss_function,
        optimizer=optimizer,
        model_constraint=model_constraint,
        monitors=monitors,
        stop_criterion=stop_criterion,
        remember_best_column=remember_best_column,
        run_after_early_stop=False,
        batch_modifier=None,
        cuda=cuda,
        do_early_stop=False,
        reset_after_second_run=False
    )
    return exp