コード例 #1
0
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            input_time_length, final_conv_length, pool_stride, n_blocks_to_add,
            sigmoid, model_constraint, batch_size, max_epochs,
            only_return_exp):
    """Assemble a Deep4Net-based experiment and, unless told otherwise, run it.

    The function builds a list of preprocessing callables (each maps
    ``(data, fs)`` to ``(data, fs)``), creates the dataset and splitter,
    builds the model, iterator, loss, monitors and stop criterion, and wraps
    everything in an ``Experiment``. Training always uses CUDA.

    Parameters
    ----------
    max_recording_mins, n_recordings
        Forwarded unchanged to ``DiagnosisSet``.
    sec_to_cut
        Seconds removed from both the start and the end of each recording
        (``int(sec_to_cut * fs)`` samples per side). ``0`` keeps everything.
    duration_recording_mins
        After cutting, recordings are truncated to at most this many minutes.
    max_abs_val
        If not ``None``, data is clipped to ``[-max_abs_val, max_abs_val]``.
    max_min_threshold, max_min_expected
        If ``max_min_threshold`` is not ``None``, both are passed to
        ``clean_jumps`` together with the constant ``200``.
    shrink_val
        If not ``None``, passed to ``shrink_spikes`` with constants 1 and 9.
    max_min_remove
        If not ``None``, used as ``threshold`` for ``set_jumps_to_zero`` with
        ``window_len=200``.
    batch_set_zero_val
        If not ``None``, a ``RemoveMinMaxDiff`` batch modifier is created
        with this value.
    batch_set_zero_test
        When it equals ``True`` (and ``batch_set_zero_val`` is set), the batch
        modifier is wrapped into the iterator via ``ModifiedIterator`` instead
        of being handed to ``Experiment`` as ``batch_modifier``.
    sampling_freq
        Target frequency for ``resampy.resample``; also the ``fs`` reported
        by the resampling step.
    low_cut_hz, high_cut_hz
        Band edges passed to ``bandpass_cnt`` (filter order 4).
    exp_demean, exp_standardize, moving_demean, moving_standardize,
    channel_demean, channel_standardize
        Truthy flags that each enable the correspondingly named step.
    divisor
        If not ``None``, the data is divided by this value as the last step.
    n_folds, i_test_fold
        Passed to ``Splitter``.
    input_time_length, final_conv_length
        Passed to ``Deep4Net``; ``input_time_length`` also sizes the dummy
        input and the crops iterator.
    pool_stride
        Used for both ``pool_time_length`` and ``pool_time_stride``.
    n_blocks_to_add
        Passed to ``net_with_more_layers`` together with ``nn.MaxPool2d``.
    sigmoid
        If truthy: one output class, the model is converted with
        ``to_linear_plus_minus_net`` and the loss is
        ``binary_cross_entropy_with_logits``. Otherwise two classes and
        ``F.nll_loss``.
    model_constraint
        Any non-``None`` value is replaced by ``MaxNormDefaultConstraint()``.
    batch_size, max_epochs
        Batch size of the iterator and the ``MaxEpochs`` stop criterion.
    only_return_exp
        If truthy, no data is loaded and nothing is trained; the experiment
        is returned with ``dataset`` and ``splitter`` attached instead.

    Returns
    -------
    The constructed (and, unless ``only_return_exp``, already run)
    ``Experiment``.
    """
    cuda = True

    def cut_start_and_end(data, fs):
        # A plain ``data[:, n:-n]`` becomes ``data[:, 0:0]`` (empty) when
        # n == 0 because ``-0 == 0``; use an open-ended stop in that case.
        # For every non-zero n this is exactly the original ``n:-n`` slice.
        n_cut = int(sec_to_cut * fs)
        stop = -n_cut if n_cut != 0 else None
        return data[:, n_cut:stop], fs

    preproc_functions = []
    preproc_functions.append(cut_start_and_end)
    # Keep at most ``duration_recording_mins`` minutes of each recording.
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # From here on the pipeline reports ``sampling_freq`` as the sampling rate.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    # The exponential running helpers are fed the transposed data and their
    # result is transposed back.
    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions)
    if not only_return_exp:
        X, y = dataset.load()

    splitter = Splitter(
        n_folds,
        i_test_fold,
    )
    if not only_return_exp:
        train_set, valid_set, test_set = splitter.split(X, y)
        del X, y  # shouldn't be necessary, but just to make sure
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 1 if sigmoid else 2
    in_chans = 21

    net = Deep4Net(
        in_chans=in_chans,
        n_classes=n_classes,
        input_time_length=input_time_length,
        final_conv_length=final_conv_length,
        pool_time_length=pool_stride,
        pool_time_stride=pool_stride,
        n_filters_2=50,
        n_filters_3=80,
        n_filters_4=120,
    )
    model = net_with_more_layers(net, n_blocks_to_add, nn.MaxPool2d)
    if sigmoid:
        model = to_linear_plus_minus_net(model)
    optimizer = optim.Adam(model.parameters())
    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # Determine the number of predictions per input by pushing a dummy batch
    # through the model and reading the size of output dimension 2.
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    # Both losses first average the predictions over dimension 2.
    if sigmoid:
        # Uses index 1 along dimension 1 of the averaged predictions.
        loss_function = lambda preds, targets: binary_cross_entropy_with_logits(
            th.mean(preds, dim=2)[:, 1, 0], targets.type_as(preds))
    else:
        loss_function = lambda preds, targets: F.nll_loss(
            th.mean(preds, dim=2)[:, :, 0], targets)

    # Any non-None value selects the default max-norm constraint.
    if model_constraint is not None:
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):
        # Move the modifier into the iterator and clear it so that it is not
        # also passed to Experiment as ``batch_modifier``.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        # Hand back the pieces needed to load and split the data later.
        exp.dataset = dataset
        exp.splitter = splitter

    return exp
コード例 #2
0
def run_exp(max_recording_mins,
            n_recordings,
            sec_to_cut_at_start,
            sec_to_cut_at_end,
            duration_recording_mins,
            max_abs_val,
            clip_before_resample,
            sampling_freq,
            divisor,
            n_folds,
            i_test_fold,
            shuffle,
            merge_train_valid,
            model,
            input_time_length,
            optimizer,
            learning_rate,
            weight_decay,
            scheduler,
            model_constraint,
            batch_size,
            max_epochs,
            only_return_exp,
            time_cut_off_sec,
            start_time,
            test_on_eval,
            test_recording_mins,
            sensor_types,
            log_dir,
            np_th_seed,
            cuda=True):
    """Set up an experiment for a given model and train it within a time budget.

    Unlike a plain ``exp.run()``, the training loop is spelled out here so
    that the elapsed wall-clock time (``time.time() - start_time``) can be
    compared with ``time_cut_off_sec`` before every epoch; when the budget is
    exceeded the experiment is returned immediately in its current state.

    Parameters
    ----------
    max_recording_mins, n_recordings, sensor_types
        Forwarded to ``DiagnosisSet``.
    sec_to_cut_at_start, sec_to_cut_at_end, duration_recording_mins,
    max_abs_val, clip_before_resample, sampling_freq, divisor
        Forwarded to ``create_preproc_functions``. ``duration_recording_mins``
        and ``sampling_freq`` are also used to check the loaded data length.
    n_folds, i_test_fold, shuffle
        Passed to the splitter. With ``test_on_eval`` the fold index is used
        as ``i_valid_fold`` of a ``TrainValidSplitter``.
    merge_train_valid
        If truthy, train and valid sets are concatenated for training and the
        valid set is cut down to its first 8 entries. Must be ``False`` for
        ``'adam'`` and ``True`` for ``'adamw'`` (asserted).
    model
        Model object; it is moved to the GPU if ``cuda``, put in eval mode and
        called on a dummy input here.
    input_time_length
        Length of the dummy input and of the crops produced by the iterator.
    optimizer
        ``'adam'`` or ``'adamw'`` (asserted).
    learning_rate, weight_decay
        Passed to the optimizer class.
    scheduler
        ``None`` or ``'cosine'`` (asserted).
    model_constraint
        ``None`` or ``'defaultnorm'`` (asserted).
    batch_size, max_epochs
        Batch size of the iterator and the ``MaxEpochs`` stop criterion.
    only_return_exp
        If truthy, nothing is loaded or trained; ``dataset`` and ``splitter``
        are attached to the returned experiment instead.
    time_cut_off_sec, start_time
        Time budget in seconds and the reference ``time.time()`` value.
    test_on_eval
        If truthy, the test set comes from a separate ``DiagnosisSet`` created
        with ``train_or_eval='eval'``.
    test_recording_mins
        Duration used for the eval-set preprocessing; ``None`` falls back to
        ``duration_recording_mins``.
    log_dir
        Passed to ``TensorboardWriter``.
    np_th_seed
        Seed passed to the crops iterator.
    cuda
        Whether model and dummy input are moved to the GPU.

    Returns
    -------
    The ``Experiment`` object.
    """
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    # Only two optimizer / merge_train_valid combinations are supported.
    if optimizer == 'adam':
        assert merge_train_valid == False
    else:
        assert optimizer == 'adamw'
        assert merge_train_valid == True

    preproc_functions = create_preproc_functions(
        sec_to_cut_at_start=sec_to_cut_at_start,
        sec_to_cut_at_end=sec_to_cut_at_end,
        duration_recording_mins=duration_recording_mins,
        max_abs_val=max_abs_val,
        clip_before_resample=clip_before_resample,
        sampling_freq=sampling_freq,
        divisor=divisor)

    dataset = DiagnosisSet(n_recordings=n_recordings,
                           max_recording_mins=max_recording_mins,
                           preproc_functions=preproc_functions,
                           train_or_eval='train',
                           sensor_types=sensor_types)

    if test_on_eval:
        if test_recording_mins is None:
            test_recording_mins = duration_recording_mins

        # Same preprocessing as for training, except for the duration.
        test_preproc_functions = create_preproc_functions(
            sec_to_cut_at_start=sec_to_cut_at_start,
            sec_to_cut_at_end=sec_to_cut_at_end,
            duration_recording_mins=test_recording_mins,
            max_abs_val=max_abs_val,
            clip_before_resample=clip_before_resample,
            sampling_freq=sampling_freq,
            divisor=divisor)
        # The eval set is created without a limit on the recording length.
        test_dataset = DiagnosisSet(n_recordings=n_recordings,
                                    max_recording_mins=None,
                                    preproc_functions=test_preproc_functions,
                                    train_or_eval='eval',
                                    sensor_types=sensor_types)
    if not only_return_exp:
        X, y = dataset.load()
        # Sanity check: the longest loaded recording must have exactly the
        # requested number of samples along axis 1.
        max_shape = np.max([list(x.shape) for x in X], axis=0)
        assert max_shape[1] == int(duration_recording_mins * sampling_freq *
                                   60)
        if test_on_eval:
            test_X, test_y = test_dataset.load()
            max_shape = np.max([list(x.shape) for x in test_X], axis=0)
            assert max_shape[1] == int(test_recording_mins * sampling_freq *
                                       60)
    # With a separate eval set only a train/valid split is needed.
    if not test_on_eval:
        splitter = TrainValidTestSplitter(n_folds,
                                          i_test_fold,
                                          shuffle=shuffle)
    else:
        splitter = TrainValidSplitter(n_folds,
                                      i_valid_fold=i_test_fold,
                                      shuffle=shuffle)
    if not only_return_exp:
        if not test_on_eval:
            train_set, valid_set, test_set = splitter.split(X, y)
        else:

            train_set, valid_set = splitter.split(X, y)
            test_set = SignalAndTarget(test_X, test_y)
            del test_X, test_y
        del X, y  # shouldn't be necessary, but just to make sure
        if merge_train_valid:
            train_set = concatenate_sets([train_set, valid_set])
            # The valid set is now part of the training data; keep only its
            # first 8 entries so that monitoring it stays cheap.
            valid_set.X = valid_set.X[:8]
            valid_set.y = valid_set.y[:8]
            # Disabled debug dump of the resulting splits:
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/train_X.npy', train_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/train_y.npy', train_set.y)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/valid_X.npy', valid_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/valid_y.npy', valid_set.y)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/test_X.npy', test_set.X)
            # np.save('/data/schirrmr/schirrmr/auto-diag/lukasrepr/compare/mne-0-16-2/test_y.npy', test_set.y)
    else:
        train_set = None
        valid_set = None
        test_set = None

    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # The model is put in eval mode for the dummy forward pass below; this
    # function does not switch it back itself.
    model.eval()
    in_chans = 21
    # Determine the number of predictions per input by pushing a dummy batch
    # through the model and reading the size of output dimension 2.
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input,
                                       seed=np_th_seed)
    # These checks repeat the ones at the top of the function.
    assert optimizer in ['adam', 'adamw'], ("Expect optimizer to be either "
                                            "adam or adamw")
    # Weight decay is only scheduled together with AdamW.
    schedule_weight_decay = optimizer == 'adamw'
    if optimizer == 'adam':
        optim_class = optim.Adam
        assert schedule_weight_decay == False
        assert merge_train_valid == False
    else:
        optim_class = AdamW
        assert schedule_weight_decay == True
        assert merge_train_valid == True

    # From here on ``optimizer`` is the optimizer object, not its name.
    optimizer = optim_class(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay)
    if scheduler is not None:
        assert scheduler == 'cosine'
        # Count the batches of one epoch by iterating over them once.
        # NOTE(review): train_set is None when only_return_exp is set;
        # confirm get_batches can handle that or that callers never combine
        # only_return_exp with a scheduler.
        n_updates_per_epoch = sum(
            [1 for _ in iterator.get_batches(train_set, shuffle=True)])
        # Adapt if you have a different number of epochs
        n_updates_per_period = n_updates_per_epoch * max_epochs
        scheduler = CosineAnnealing(n_updates_per_period)
        optimizer = ScheduledOptimizer(
            scheduler, optimizer, schedule_weight_decay=schedule_weight_decay)
    loss_function = nll_loss_on_mean

    if model_constraint is not None:
        assert model_constraint == 'defaultnorm'
        model_constraint = MaxNormDefaultConstraint()
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedDiagnosisMonitor(input_time_length, n_preds_per_input),
        RuntimeMonitor(),
    ]

    stop_criterion = MaxEpochs(max_epochs)
    loggers = [Printer(), TensorboardWriter(log_dir)]
    batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda,
                     loggers=loggers)

    if not only_return_exp:
        # First phase: train until the stop criterion fires. Before training,
        # monitor and log the initial state and let the rememberer record it.
        exp.setup_training()
        exp.monitor_epoch(exp.datasets)
        exp.log_epoch()
        exp.rememberer.remember_epoch(exp.epochs_df, exp.model, exp.optimizer)

        exp.iterator.reset_rng()
        while not exp.stop_criterion.should_stop(exp.epochs_df):
            # Check the time budget before starting another epoch.
            # NOTE(review): the early returns in this block skip the
            # ``exp.test_dataset`` assignment at the end of the function;
            # confirm callers do not rely on it after a time-out.
            if (time.time() - start_time) > time_cut_off_sec:
                log.info(
                    "Ran out of time after {:.2f} sec.".format(time.time() -
                                                               start_time))
                return exp
            log.info("Still in time after {:.2f} sec.".format(time.time() -
                                                              start_time))
            exp.run_one_epoch(exp.datasets, remember_best=True)
        if (time.time() - start_time) > time_cut_off_sec:
            log.info("Ran out of time after {:.2f} sec.".format(time.time() -
                                                                start_time))
            return exp
        if not merge_train_valid:
            exp.setup_after_stop_training()
            # Second phase: continue on train + valid until the stop
            # criterion fires again.
            datasets = exp.datasets
            datasets['train'] = concatenate_sets(
                [datasets['train'], datasets['valid']])
            exp.monitor_epoch(datasets)
            exp.log_epoch()

            exp.iterator.reset_rng()
            while not exp.stop_criterion.should_stop(exp.epochs_df):
                if (time.time() - start_time) > time_cut_off_sec:
                    log.info("Ran out of time after {:.2f} sec.".format(
                        time.time() - start_time))
                    return exp
                log.info("Still in time after {:.2f} sec.".format(time.time() -
                                                                  start_time))
                exp.run_one_epoch(datasets, remember_best=False)

    else:
        # Hand back the pieces needed to load and split the data later.
        exp.dataset = dataset
        exp.splitter = splitter
    if test_on_eval:
        exp.test_dataset = test_dataset

    return exp
コード例 #3
0
def run_exp(max_recording_mins, n_recordings, sec_to_cut,
            duration_recording_mins, max_abs_val, max_min_threshold,
            max_min_expected, shrink_val, max_min_remove, batch_set_zero_val,
            batch_set_zero_test, sampling_freq, low_cut_hz, high_cut_hz,
            exp_demean, exp_standardize, moving_demean, moving_standardize,
            channel_demean, channel_standardize, divisor, n_folds, i_test_fold,
            model_name, input_time_length, final_conv_length, batch_size,
            max_epochs, only_return_exp):
    """Assemble a shallow/deep ConvNet experiment on selected recordings.

    The function builds a list of preprocessing callables (each maps
    ``(data, fs)`` to ``(data, fs)``), selects recordings by length, splits
    them into train/valid/test folds, builds the requested model and wraps
    everything in an ``Experiment``, which is run unless ``only_return_exp``
    is set. Training always uses CUDA.

    Parameters
    ----------
    max_recording_mins
        Only recordings whose stored length is below
        ``max_recording_mins * 60`` are kept.
    n_recordings
        Number of indices that are distributed over the folds.
    sec_to_cut
        Seconds removed from both the start and the end of each recording
        (``int(sec_to_cut * fs)`` samples per side). ``0`` keeps everything.
    duration_recording_mins
        After cutting, recordings are truncated to at most this many minutes.
    max_abs_val
        If not ``None``, data is clipped to ``[-max_abs_val, max_abs_val]``.
    max_min_threshold, max_min_expected
        If ``max_min_threshold`` is not ``None``, both are passed to
        ``clean_jumps`` together with the constant ``200``.
    shrink_val
        If not ``None``, passed to ``shrink_spikes`` with constants 1 and 9.
    max_min_remove
        If not ``None``, used as ``threshold`` for ``set_jumps_to_zero`` with
        ``window_len=200``.
    batch_set_zero_val
        If not ``None``, a ``RemoveMinMaxDiff`` batch modifier is created
        with this value.
    batch_set_zero_test
        When it equals ``True`` (and ``batch_set_zero_val`` is set), the batch
        modifier is wrapped into the iterator via ``ModifiedIterator`` instead
        of being handed to ``Experiment`` as ``batch_modifier``.
    sampling_freq
        Target frequency for ``resampy.resample``; also the ``fs`` reported
        by the resampling step.
    low_cut_hz, high_cut_hz
        Band edges passed to ``bandpass_cnt`` (filter order 4).
    exp_demean, exp_standardize, moving_demean, moving_standardize,
    channel_demean, channel_standardize
        Truthy flags that each enable the correspondingly named step.
    divisor
        If not ``None``, the data is divided by this value as the last step.
    n_folds, i_test_fold
        Number of folds and index of the test fold; the validation fold is
        ``folds[i_test_fold - 1]``.
    model_name
        ``'shallow'`` (``ShallowFBCSPNet``) or ``'deep'`` (``Deep4Net``).
    input_time_length, final_conv_length
        Passed to the model constructor; ``input_time_length`` also sizes the
        dummy input and the crops iterator.
    batch_size, max_epochs
        Batch size of the iterator and the ``MaxEpochs`` stop criterion.
    only_return_exp
        If truthy, no recordings are loaded and nothing is trained.

    Returns
    -------
    The constructed (and, unless ``only_return_exp``, already run)
    ``Experiment``.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``.
    """
    # Fail before any data is loaded. Previously an unknown name left
    # ``model`` unbound and only surfaced later as an UnboundLocalError.
    if model_name not in ('shallow', 'deep'):
        raise ValueError(
            "Unknown model_name {!r}; expected 'shallow' or 'deep'".format(
                model_name))

    cuda = True

    def cut_start_and_end(data, fs):
        # A plain ``data[:, n:-n]`` becomes ``data[:, 0:0]`` (empty) when
        # n == 0 because ``-0 == 0``; use an open-ended stop in that case.
        # For every non-zero n this is exactly the original ``n:-n`` slice.
        n_cut = int(sec_to_cut * fs)
        stop = -n_cut if n_cut != 0 else None
        return data[:, n_cut:stop], fs

    preproc_functions = []
    preproc_functions.append(cut_start_and_end)
    # Keep at most ``duration_recording_mins`` minutes of each recording.
    preproc_functions.append(lambda data, fs: (data[:, :int(
        duration_recording_mins * 60 * fs)], fs))
    if max_abs_val is not None:
        preproc_functions.append(
            lambda data, fs: (np.clip(data, -max_abs_val, max_abs_val), fs))
    if max_min_threshold is not None:
        preproc_functions.append(lambda data, fs: (clean_jumps(
            data, 200, max_min_threshold, max_min_expected, cuda), fs))
    if max_min_remove is not None:
        window_len = 200
        preproc_functions.append(lambda data, fs: (set_jumps_to_zero(
            data,
            window_len=window_len,
            threshold=max_min_remove,
            cuda=cuda,
            clip_min_max_to_zero=True), fs))

    if shrink_val is not None:
        preproc_functions.append(lambda data, fs: (shrink_spikes(
            data,
            shrink_val,
            1,
            9,
        ), fs))

    # From here on the pipeline reports ``sampling_freq`` as the sampling rate.
    preproc_functions.append(lambda data, fs: (resampy.resample(
        data, fs, sampling_freq, axis=1, filter='kaiser_fast'), sampling_freq))
    preproc_functions.append(lambda data, fs: (bandpass_cnt(
        data, low_cut_hz, high_cut_hz, fs, filt_order=4, axis=1), fs))

    # The exponential running helpers are fed the transposed data and their
    # result is transposed back.
    if exp_demean:
        preproc_functions.append(lambda data, fs: (exponential_running_demean(
            data.T, factor_new=0.001, init_block_size=100).T, fs))
    if exp_standardize:
        preproc_functions.append(
            lambda data, fs: (exponential_running_standardize(
                data.T, factor_new=0.001, init_block_size=100).T, fs))
    if moving_demean:
        preproc_functions.append(lambda data, fs: (padded_moving_demean(
            data, axis=1, n_window=201), fs))
    if moving_standardize:
        preproc_functions.append(lambda data, fs: (padded_moving_standardize(
            data, axis=1, n_window=201), fs))
    if channel_demean:
        preproc_functions.append(lambda data, fs: (demean(data, axis=1), fs))
    if channel_standardize:
        preproc_functions.append(lambda data, fs:
                                 (standardize(data, axis=1), fs))
    if divisor is not None:
        preproc_functions.append(lambda data, fs: (data / divisor, fs))

    # Keep only recordings shorter than the requested maximum; the lengths
    # come from a precomputed file and are compared against seconds.
    all_file_names, labels = get_all_sorted_file_names_and_labels()
    lengths = np.load(
        '/home/schirrmr/code/auto-diagnosis/sorted-recording-lengths.npy')
    mask = lengths < max_recording_mins * 60
    cleaned_file_names = np.array(all_file_names)[mask]
    cleaned_labels = labels[mask]

    # NOTE(review): presumably one value per recording in the same order as
    # the cleaned file list - confirm against the script that wrote the file.
    diffs_per_rec = np.load(
        '/home/schirrmr/code/auto-diagnosis/diffs_per_recording.npy')

    def create_set(inds):
        """Load and preprocess the recordings at ``inds`` into one set."""
        X = []
        for i in inds:
            log.info("Load {:s}".format(cleaned_file_names[i]))
            x = load_data(cleaned_file_names[i], preproc_functions)
            X.append(x)
        y = cleaned_labels[inds].astype(np.int64)
        return SignalAndTarget(X, y)

    if not only_return_exp:
        folds = get_balanced_batches(n_recordings,
                                     None,
                                     False,
                                     n_batches=n_folds)
        test_inds = folds[i_test_fold]
        # The fold before the test fold is used for validation.
        valid_inds = folds[i_test_fold - 1]
        all_inds = list(range(n_recordings))
        train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))

        # The fold indices are positions in the ranking of recordings by
        # descending ``diffs_per_rec``; map them to recording numbers.
        rec_nr_sorted_by_diff = np.argsort(diffs_per_rec)[::-1]
        train_inds = rec_nr_sorted_by_diff[train_inds]
        valid_inds = rec_nr_sorted_by_diff[valid_inds]
        test_inds = rec_nr_sorted_by_diff[test_inds]

        train_set = create_set(train_inds)
        valid_set = create_set(valid_inds)
        test_set = create_set(test_inds)
    else:
        train_set = None
        valid_set = None
        test_set = None

    set_random_seeds(seed=20170629, cuda=cuda)
    n_classes = 2
    in_chans = 21
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            in_chans=in_chans,
            n_classes=n_classes,
            input_time_length=input_time_length,
            final_conv_length=final_conv_length).create_network()
    else:
        # model_name == 'deep' (validated at the top of the function)
        model = Deep4Net(in_chans,
                         n_classes,
                         input_time_length=input_time_length,
                         final_conv_length=final_conv_length).create_network()

    optimizer = optim.Adam(model.parameters())
    to_dense_prediction_model(model)
    log.info("Model:\n{:s}".format(str(model)))
    if cuda:
        model.cuda()
    # Determine the number of predictions per input by pushing a dummy batch
    # through the model and reading the size of output dimension 2.
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    if cuda:
        test_input = test_input.cuda()
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]
    log.info("{:d} predictions per input/trial".format(n_preds_per_input))
    iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                       input_time_length=input_time_length,
                                       n_preds_per_input=n_preds_per_input)
    # The loss is computed on the predictions averaged over dimension 2.
    loss_function = lambda preds, targets: F.nll_loss(
        th.mean(preds, dim=2)[:, :, 0], targets)
    model_constraint = None
    monitors = [
        LossMonitor(),
        MisclassMonitor(col_suffix='sample_misclass'),
        CroppedTrialMisclassMonitor(input_time_length),
        RuntimeMonitor(),
    ]
    stop_criterion = MaxEpochs(max_epochs)
    batch_modifier = None
    if batch_set_zero_val is not None:
        batch_modifier = RemoveMinMaxDiff(batch_set_zero_val,
                                          clip_max_abs=True,
                                          set_zero=True)
    if (batch_set_zero_val is not None) and (batch_set_zero_test == True):
        # Move the modifier into the iterator and clear it so that it is not
        # also passed to Experiment as ``batch_modifier``.
        iterator = ModifiedIterator(
            iterator,
            batch_modifier,
        )
        batch_modifier = None
    exp = Experiment(model,
                     train_set,
                     valid_set,
                     test_set,
                     iterator,
                     loss_function,
                     optimizer,
                     model_constraint,
                     monitors,
                     stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True,
                     batch_modifier=batch_modifier,
                     cuda=cuda)
    if not only_return_exp:
        exp.run()
    else:
        # This variant has no dataset/splitter objects to hand back.
        exp.dataset = None
        exp.splitter = None

    return exp