Example #1
import torch
from torch import nn

# import paths assume the braindecode 0.4.x API (networks built via
# create_network())
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model


class HybridNetModule(nn.Module):
    def __init__(self, in_chans, n_classes, input_time_length):
        super(HybridNetModule, self).__init__()
        deep_model = Deep4Net(
            in_chans,
            n_classes,
            n_filters_time=20,
            n_filters_spat=30,
            n_filters_2=40,
            n_filters_3=50,
            n_filters_4=60,
            input_time_length=input_time_length,
            final_conv_length=2,
        ).create_network()
        shallow_model = ShallowFBCSPNet(
            in_chans,
            n_classes,
            input_time_length=input_time_length,
            n_filters_time=30,
            n_filters_spat=40,
            filter_time_length=28,
            final_conv_length=29,
        ).create_network()

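        # copy the deep model up to (but excluding) its classifier, replacing
        # the final classification conv with a 60-filter conv so the network
        # outputs features instead of class scores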
        reduced_deep_model = nn.Sequential()
        for name, module in deep_model.named_children():
            if name == "conv_classifier":
                new_conv_layer = nn.Conv2d(
                    module.in_channels,
                    60,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                )
                reduced_deep_model.add_module("deep_final_conv",
                                              new_conv_layer)
                break
            reduced_deep_model.add_module(name, module)

        reduced_shallow_model = nn.Sequential()
        for name, module in shallow_model.named_children():
            if name == "conv_classifier":
                new_conv_layer = nn.Conv2d(
                    module.in_channels,
                    40,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                )
                reduced_shallow_model.add_module("shallow_final_conv",
                                                 new_conv_layer)
                break
            reduced_shallow_model.add_module(name, module)

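        # convert both networks to dense prediction models so they emit one
        # prediction per input sample (cropped decoding) instead of one per
        # full input window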
        to_dense_prediction_model(reduced_deep_model)
        to_dense_prediction_model(reduced_shallow_model)
        self.reduced_deep_model = reduced_deep_model
        self.reduced_shallow_model = reduced_shallow_model
        self.final_conv = nn.Conv2d(100,
                                    n_classes,
                                    kernel_size=(1, 1),
                                    stride=1)
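
    def forward(self, x):
        # Not shown in the original example; a minimal sketch, assuming the
        # usual (batch, channels, time, 1) outputs of the two sub-networks.
        # The time-cropping used to align the outputs is an assumption, not
        # the original code.
        deep_out = self.reduced_deep_model(x)
        shallow_out = self.reduced_shallow_model(x)
        # crop both outputs to the shorter time length before stacking
        n_preds = min(deep_out.size(2), shallow_out.size(2))
        merged = torch.cat(
            (deep_out[:, :, :n_preds], shallow_out[:, :, :n_preds]), dim=1)
        out = self.final_conv(merged)  # 60 deep + 40 shallow = 100 channels
        return nn.functional.log_softmax(out, dim=1)
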
Example #2
import logging
import os

import numpy as np
from sklearn.model_selection import KFold
from torch import nn

# import paths below assume the braindecode 0.4.x API; the remaining names
# (Tuh, TuhLazy, TuhSubset, TuhLazySubset, remove_file_from_dataset,
# RAMMonitor, RMSEMonitor, CroppedDiagnosisMonitor, LazyMisclassMonitor,
# LazyCropsFromTrialsIterator, ExtendedAdam, AdamWWithTracking,
# mse_loss_on_mean, nll_loss_on_mean, TemporalConvNet) are project-specific
# helpers assumed importable from the surrounding codebase
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
from braindecode.models.deep4 import Deep4Net
from braindecode.models.eegnet import EEGNetv4
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
from braindecode.torch_ext.util import np_to_var, set_random_seeds


def setup_exp(
    train_folder,
    n_recordings,
    n_chans,
    model_name,
    n_start_chans,
    n_chan_factor,
    input_time_length,
    final_conv_length,
    model_constraint,
    stride_before_pool,
    init_lr,
    batch_size,
    max_epochs,
    cuda,
    num_workers,
    task,
    weight_decay,
    n_folds,
    shuffle_folds,
    lazy_loading,
    eval_folder,
    result_folder,
    run_on_normals,
    run_on_abnormals,
    seed,
    l2_decay,
    gradient_clip,
):
    info_msg = "using {}, {}".format(
        os.environ["SLURM_JOB_PARTITION"],
        os.environ["SLURMD_NODENAME"],
    )
    info_msg += ", gpu {}".format(os.environ["CUDA_VISIBLE_DEVICES"])
    logging.info(info_msg)

    logging.info("Targets for this task: <{}>".format(task))

    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True

    if task == "age":
        loss_function = mse_loss_on_mean
        remember_best_column = "valid_rmse"
        n_classes = 1
    else:
        loss_function = nll_loss_on_mean
        remember_best_column = "valid_misclass"
        n_classes = 2

    if model_constraint is not None:
        assert model_constraint == 'defaultnorm'
        model_constraint = MaxNormDefaultConstraint()

    stop_criterion = MaxEpochs(max_epochs)

    set_random_seeds(seed=seed, cuda=cuda)
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=n_chans,
            n_classes=n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    elif model_name == 'deep':
        model = Deep4Net(
            n_chans,
            n_classes,
            n_filters_time=n_start_chans,
            n_filters_spat=n_start_chans,
            input_time_length=input_time_length,
            n_filters_2=int(n_start_chans * n_chan_factor),
            n_filters_3=int(n_start_chans * (n_chan_factor**2.0)),
            n_filters_4=int(n_start_chans * (n_chan_factor**3.0)),
            final_conv_length=final_conv_length,
            stride_before_pool=stride_before_pool).create_network()
    elif model_name == 'eegnet':
        model = EEGNetv4(n_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=final_conv_length).create_network()
    elif model_name == "tcn":
        assert task != "age", "what to change to do regression with tcn?!"
        model = TemporalConvNet(input_size=n_chans,
                                output_size=n_classes,
                                context_size=0,
                                num_channels=55,
                                num_levels=5,
                                kernel_size=16,
                                dropout=0.05270154233150525,
                                skip_mode=None,
                                use_context=0,
                                lasso_selection=0.0,
                                rnn_normalization=None)
    else:
        assert False, "unknown model name {:s}".format(model_name)

    if task == "age":
        # remove softmax layer, set n_classes to 1
        model.n_classes = 1
        new_model = nn.Sequential()
        for name, module_ in model.named_children():
            if name == "softmax":
                continue
            new_model.add_module(name, module_)
        model = new_model

    # TODO: in case all CUDA devices are busy, maybe check whether this
    # works and wait / retry after some time?
    if cuda:
        model.cuda()

    if model_name != "tcn":
        to_dense_prediction_model(model)
    logging.info("Model:\n{:s}".format(str(model)))

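    # push a dummy batch through the model once to determine how many
    # predictions the dense model emits per input window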
    test_input = np_to_var(
        np.ones((2, n_chans, input_time_length, 1), dtype=np.float32))
    if list(model.parameters())[0].is_cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]

    if eval_folder is None:
        logging.info("will do validation")
        if lazy_loading:
            logging.info(
                "using lazy loading to load {} recs".format(n_recordings))
            dataset = TuhLazy(train_folder,
                              target=task,
                              n_recordings=n_recordings)
        else:
            logging.info("using traditional loading to load {} recs".format(
                n_recordings))
            dataset = Tuh(train_folder, n_recordings=n_recordings, target=task)

        assert not (run_on_normals and run_on_abnormals), (
            "decide whether to run on normal or abnormal subjects")
        # only run on normal subjects
        if run_on_normals:
            ids = [
                i for i in range(len(dataset)) if dataset.pathologicals[i] == 0
            ]  # 0 is non-pathological
            dataset = TuhSubset(dataset, ids)
            logging.info("only using {} normal subjects".format(len(dataset)))
        if run_on_abnormals:
            ids = [
                i for i in range(len(dataset)) if dataset.pathologicals[i] == 1
            ]  # 1 is pathological
            dataset = TuhSubset(dataset, ids)
            logging.info("only using {} abnormal subjects".format(
                len(dataset)))

        indices = np.arange(len(dataset))
        kf = KFold(n_splits=n_folds, shuffle=shuffle_folds)
        for i, (train_ind, test_ind) in enumerate(kf.split(indices)):
            assert len(np.intersect1d(
                train_ind, test_ind)) == 0, ("train and test set overlap!")

            # seed is in the range of the number of folds and was set by the
            # submit script
            if i == seed:
                break

        if lazy_loading:
            test_subset = TuhLazySubset(dataset, test_ind)
            train_subset = TuhLazySubset(dataset, train_ind)
        else:
            test_subset = TuhSubset(dataset, test_ind)
            train_subset = TuhSubset(dataset, train_ind)
    else:
        logging.info("will do final evaluation")
        if lazy_loading:
            train_subset = TuhLazy(train_folder, target=task)
            test_subset = TuhLazy(eval_folder, target=task)
        else:
            train_subset = Tuh(train_folder, target=task)
            test_subset = Tuh(eval_folder, target=task)

        # remove rec:
        # train/abnormal/01_tcp_ar/081/00008184/s001_2011_09_21/00008184_s001_t001
        # since it contains no crop without outliers (channels A1, A2 broken)
        subjects = [f.split("/")[-3] for f in train_subset.file_paths]
        if "00008184" in subjects:
            bad_id = subjects.index("00008184")
            train_subset = remove_file_from_dataset(
                train_subset,
                file_id=bad_id,
                file=("train/abnormal/01_tcp_ar/081/00008184/s001_2011_09_21/"
                      "00008184_s001_t001"))
        subjects = [f.split("/")[-3] for f in test_subset.file_paths]
        if "00008184" in subjects:
            bad_id = subjects.index("00008184")
            test_subset = remove_file_from_dataset(
                test_subset,
                file_id=bad_id,
                file=("train/abnormal/01_tcp_ar/081/00008184/s001_2011_09_21/"
                      "00008184_s001_t001"))

    if task == "age":
        # standardize ages based on train set
        y_train = train_subset.y
        y_train_mean = np.mean(y_train)
        y_train_std = np.std(y_train)
        train_subset.y = (y_train - y_train_mean) / y_train_std
        y_test = test_subset.y
        test_subset.y = (y_test - y_train_mean) / y_train_std

    if lazy_loading:
        iterator = LazyCropsFromTrialsIterator(
            input_time_length,
            n_preds_per_input,
            batch_size,
            seed=seed,
            num_workers=num_workers,
            reset_rng_after_each_batch=False,
            check_preds_smaller_trial_len=False)  # TODO: should probably be True
    else:
        iterator = CropsFromTrialsIterator(batch_size, input_time_length,
                                           n_preds_per_input, seed)

    monitors = []
    monitors.append(LossMonitor())
    monitors.append(RAMMonitor())
    monitors.append(RuntimeMonitor())
    if task == "age":
        monitors.append(
            RMSEMonitor(input_time_length,
                        n_preds_per_input,
                        mean=y_train_mean,
                        std=y_train_std))
    else:
        monitors.append(
            CroppedDiagnosisMonitor(input_time_length, n_preds_per_input))
        monitors.append(LazyMisclassMonitor(col_suffix='sample_misclass'))

    if lazy_loading:
        n_updates_per_epoch = len(
            iterator.get_batches(train_subset, shuffle=False))
    else:
        n_updates_per_epoch = sum(
            [1 for _ in iterator.get_batches(train_subset, shuffle=False)])
    n_updates_per_period = n_updates_per_epoch * max_epochs
    logging.info("there are {} updates per epoch".format(n_updates_per_epoch))

    if model_name == "tcn":
        adamw = ExtendedAdam(model.parameters(),
                             lr=init_lr,
                             weight_decay=weight_decay,
                             l2_decay=l2_decay,
                             gradient_clip=gradient_clip)
    else:
        adamw = AdamWWithTracking(model.parameters(),
                                  init_lr,
                                  weight_decay=weight_decay)

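    # a single cosine annealing cycle spanning every update of the whole run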
    scheduler = CosineAnnealing(n_updates_per_period)
    optimizer = ScheduledOptimizer(scheduler,
                                   adamw,
                                   schedule_weight_decay=True)

    exp = Experiment(model=model,
                     train_set=train_subset,
                     valid_set=None,
                     test_set=test_subset,
                     iterator=iterator,
                     loss_function=loss_function,
                     optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column=remember_best_column,
                     run_after_early_stop=False,
                     batch_modifier=None,
                     cuda=cuda,
                     do_early_stop=False,
                     reset_after_second_run=False)
    return exp
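
# Usage sketch (not part of the original source): all values below are
# illustrative assumptions, not the authors' settings. Note that setup_exp
# reads the SLURM_* and CUDA_VISIBLE_DEVICES environment variables, so it
# expects to run inside a SLURM job.
exp = setup_exp(
    train_folder="data/train",
    n_recordings=300,
    n_chans=21,
    model_name="deep",
    n_start_chans=25,
    n_chan_factor=2,
    input_time_length=6000,
    final_conv_length=1,
    model_constraint="defaultnorm",
    stride_before_pool=True,
    init_lr=1e-3,
    batch_size=64,
    max_epochs=35,
    cuda=True,
    num_workers=0,
    task="pathological",
    weight_decay=5e-4,
    n_folds=5,
    shuffle_folds=False,
    lazy_loading=False,
    eval_folder=None,
    result_folder="results",
    run_on_normals=False,
    run_on_abnormals=False,
    seed=0,
    l2_decay=0.0,
    gradient_clip=0.0,
)
exp.run()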