else: model = Deep4Net(in_chans=in_chans, n_classes=n_classes, input_time_length=None, final_conv_length=1) if cuda: model.cuda() from braindecode.torch_ext.optimizers import AdamW import torch.nn.functional as F if model_type == 'shallow': optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0) else: optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model if train_type == 'trialwise' : model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1) else: # cropped model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, cropped=True) # Compile model exactly the same way as when you trained it print("INFO : Epochs: {}".format(epoches)) print("INFO : Batch Size: {}".format(batch_size)) # Fit model exactly the same way as when you trained it (omit any optional params though) if train_type == 'trialwise': print(model.fit(train_set.X, train_set.y, epochs=epoches, batch_size=batch_size, scheduler='cosine', validation_data=(valid_set.X, valid_set.y),)) else: # cropped input_time_length = 450
def run_exp(
    debug,
    subject_id,
    max_epochs,
    n_sensors,
    final_hz,
    half_before,
    start_ms,
    stop_ms,
    model,
    weight_decay,
    final_fft,
    add_bnorm,
    act_norm,
):
    """Train one model on one subject's data and return its training history.

    Loads per-class train/test tensors with ``load_train_test``, turns each
    split into a braindecode ``SignalAndTarget`` (the label of a trial is the
    index of the class tensor it came from), builds the requested network and
    fits it with AdamW and a cosine schedule, using the test split as
    validation data. Everything runs on CUDA (``cuda`` is hard-coded to True).

    Args:
        debug: If true, the ``car`` flag passed to ``load_train_test`` is
            False and only the given sensors are loaded
            (``only_load_given_sensors=debug``).
        subject_id: Subject to load; forwarded to ``load_train_test``.
        max_epochs: Number of training epochs for ``model.fit``.
        n_sensors, half_before, start_ms, stop_ms: Forwarded unchanged to
            ``load_train_test``.
        final_hz: Forwarded to ``load_train_test``; must be 64 or 256.
        model: Model name, one of ``'shallow'``, ``'deep'``, ``'invertible'``
            or ``'deep_invertible'``.
        weight_decay: AdamW weight decay for ``'invertible'`` and
            ``'deep_invertible'``. Must be the string ``'hardcoded'`` for
            ``'shallow'`` and ``'deep'``, which use fixed hyperparameters.
        final_fft, add_bnorm: Forwarded to ``InvertibleModel`` (only used for
            ``'invertible'``).
        act_norm: Only used for ``'deep_invertible'``. If true, one forward
            pass over the first class's training inputs is run with
            ``scale_to_unit_var`` registered as forward hook on every module
            that has a ``log_factor`` attribute; the hooks are removed again
            afterwards.

    Returns:
        Tuple ``(model.epochs_df, model.network)`` of the trained model.

    Raises:
        ValueError: If ``final_hz``, ``model`` or ``weight_decay`` is invalid.
            Validation happens before any data is loaded.
    """
    model_name = model
    del model  # ``model`` is reused below for the network object

    # Validate up front (not with ``assert``, which ``python -O`` strips), so
    # a bad config fails before the expensive data loading and model setup.
    if final_hz not in (64, 256):
        raise ValueError("final_hz must be 64 or 256, got {!r}".format(final_hz))
    if model_name not in ('shallow', 'deep', 'invertible', 'deep_invertible'):
        raise ValueError("Unknown model {!r}".format(model_name))
    if model_name in ('shallow', 'deep') and weight_decay != 'hardcoded':
        raise ValueError(
            "weight_decay must be 'hardcoded' for model {!r}, got {!r}".format(
                model_name, weight_decay))

    # ---- data -------------------------------------------------------------
    car = not debug
    train_inputs, test_inputs = load_train_test(
        subject_id,
        car,
        n_sensors,
        final_hz,
        start_ms,
        stop_ms,
        half_before,
        only_load_given_sensors=debug,
    )
    cuda = True
    if cuda:
        train_inputs = [i.cuda() for i in train_inputs]
        test_inputs = [i.cuda() for i in test_inputs]

    from braindecode.datautil.signal_target import SignalAndTarget

    # One SignalAndTarget per split; trials of the i-th class tensor get label i.
    datasets = []
    for inputs in (train_inputs, test_inputs):
        X = np.concatenate([var_to_np(ins) for ins in inputs]).astype(np.float32)
        y = np.concatenate(
            [np.ones(len(ins)) * i_class for i_class, ins in enumerate(inputs)]
        ).astype(np.int64)
        datasets.append(SignalAndTarget(X, y))
    train_set, valid_set = datasets

    # ---- model ------------------------------------------------------------
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from braindecode.models.deep4 import Deep4Net
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds
    import torch

    set_random_seeds(2019011641, cuda)
    n_chans = train_inputs[0].shape[1]
    n_time = train_inputs[0].shape[2]
    # NOTE(review): labels are 0..len(train_inputs)-1 but n_classes is fixed at
    # 2, so this assumes load_train_test returns exactly two class tensors.
    n_classes = 2
    input_time_length = train_set.X.shape[2]

    if model_name == 'shallow':
        # final_conv_length='auto' ensures we only get a single output in the
        # time dimension
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            input_time_length=input_time_length,
            final_conv_length='auto',
        )
    elif model_name == 'deep':
        model = Deep4Net(
            n_chans,
            n_classes,
            input_time_length=input_time_length,
            pool_time_length=2,
            pool_time_stride=2,
            final_conv_length='auto',
        )
    elif model_name == 'invertible':
        model = InvertibleModel(
            n_chans, n_time, final_fft=final_fft, add_bnorm=add_bnorm)
    else:  # 'deep_invertible' (model_name was validated above)
        n_chan_pad = 0
        filter_length_time = 11
        model = deep_invertible(
            n_chans, input_time_length, n_chan_pad, filter_length_time)
        # Keep the first two entries of dim 1 at index 0 of dim 2 as the
        # class scores (dim meanings presumably channels/time -- confirm).
        model.add_module("select_dims", Expression(lambda x: x[:, :2, 0]))
        model.add_module("softmax", nn.LogSoftmax(dim=1))
        model = WrappedModel(model)

        ## set scale
        if act_norm:
            model.cuda()
            for module in model.network.modules():
                if hasattr(module, 'log_factor'):
                    module._forward_hooks.clear()
                    module.register_forward_hook(scale_to_unit_var)
            # The forward pass only exists to trigger the hooks; its output is
            # discarded, so no autograd graph is needed.
            with torch.no_grad():
                model.network(train_inputs[0].cuda())
            for module in model.network.modules():
                if hasattr(module, 'log_factor'):
                    module._forward_hooks.clear()

    if cuda:
        model.cuda()

    # ---- optimisation -----------------------------------------------------
    from braindecode.torch_ext.optimizers import AdamW
    import torch.nn.functional as F

    if model_name == 'shallow':
        lr, decay = 0.0625 * 0.01, 0
    elif model_name == 'deep':
        # these are good values for the deep model
        lr, decay = 1 * 0.01, 0.5 * 0.001
    elif model_name == 'invertible':
        lr, decay = 1e-4, weight_decay
    else:  # 'deep_invertible'
        lr, decay = 1 * 0.001, weight_decay
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=decay)

    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
    model.fit(
        train_set.X,
        train_set.y,
        epochs=max_epochs,
        batch_size=64,
        scheduler='cosine',
        validation_data=(valid_set.X, valid_set.y),
    )
    return model.epochs_df, model.network