Example No. 1
def load_train_valid_tuh(n_subjects, n_seconds, ids_to_load):
    path = '/home/schirrmr/data/preproced-tuh/all-sensors-32-hz/'
    log.info("Load concat dataset...")
    dataset = load_concat_dataset(path, preload=False, ids_to_load=ids_to_load)
    whole_train_set = dataset.split('session')['train']
    n_max_minutes = int(np.ceil(n_seconds / 60) + 2)
    sfreq = whole_train_set.datasets[0].raw.info['sfreq']
    log.info("Preprocess concat dataset...")
    preprocess(whole_train_set, [
        MNEPreproc('crop', tmin=0, tmax=n_max_minutes * 60, include_tmax=True),
        NumpyPreproc(fn=lambda x: np.clip(x, -80, 80)),
        NumpyPreproc(fn=lambda x: x / 3),
        NumpyPreproc(fn=exponential_moving_demean,
                     init_block_size=int(sfreq * 10),
                     factor_new=1 / (sfreq * 5)),
    ])
    subject_datasets = whole_train_set.split('subject')

    n_split = int(np.round(n_subjects * 0.75))
    keys = list(subject_datasets.keys())
    train_sets = [
        d for i in range(n_split) for d in subject_datasets[keys[i]].datasets
    ]
    train_set = BaseConcatDataset(train_sets)
    valid_sets = [
        d for i in range(n_split, n_subjects)
        for d in subject_datasets[keys[i]].datasets
    ]
    valid_set = BaseConcatDataset(valid_sets)

    train_set = create_fixed_length_windows(
        train_set,
        start_offset_samples=60 * 32,
        stop_offset_samples=60 * 32 + 32 * n_seconds,
        preload=True,
        window_size_samples=128,
        window_stride_samples=64,
        drop_last_window=True,
    )

    valid_set = create_fixed_length_windows(
        valid_set,
        start_offset_samples=60 * 32,
        stop_offset_samples=60 * 32 + 32 * n_seconds,
        preload=True,
        window_size_samples=128,
        window_stride_samples=64,
        drop_last_window=True,
    )
    return train_set, valid_set
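Below is a minimal, hedged call sketch for the loader above; the subject count, duration, and recording ids are hypothetical and must match what is actually stored under the hard-coded path.

# Hypothetical call: 10 subjects, 600 s per recording, first 10 stored ids.
train_set, valid_set = load_train_valid_tuh(
    n_subjects=10, n_seconds=600, ids_to_load=list(range(10)))
print(len(train_set), len(valid_set))  # counts of 128-sample (4 s at 32 Hz) windows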
Example No. 2
def test_windows_from_events_different_events(tmpdir_factory):
    description_expected = 5 * ['T0', 'T1'] + 4 * ['T2', 'T3'] + 2 * ['T1']
    raw = _get_raw(tmpdir_factory, description_expected[:10])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))

    raw_1 = _get_raw(tmpdir_factory, description_expected[10:])
    base_ds_1 = BaseDataset(raw_1, description=pd.Series({'file_id': 2}))
    concat_ds = BaseConcatDataset([base_ds, base_ds_1])

    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False)
    description = []
    events = []
    for ds in windows.datasets:
        description += ds.windows.metadata['target'].to_list()
        events += ds.windows.events[:, 0].tolist()

    assert len(description) == 20
    np.testing.assert_array_equal(description,
                                  5 * [0, 1] + 4 * [2, 3] + 2 * [1])
    np.testing.assert_array_equal(
        np.concatenate(
            [raw.time_as_index(raw.annotations.onset, use_rounding=True),
             raw_1.time_as_index(raw_1.annotations.onset, use_rounding=True)]),
        events)
Example No. 3
def preprocess(concat_ds,
               preprocessors,
               save_dir=None,
               overwrite=False,
               n_jobs=None):
    """Apply preprocessors to a concat dataset.

    Parameters
    ----------
    concat_ds: BaseConcatDataset
        A concat of BaseDataset or WindowsDataset datasets to be preprocessed.
    preprocessors: list(Preprocessor)
        List of Preprocessor objects to apply to the dataset.
    save_dir : str | None
        If a string, the preprocessed data will be saved under the specified
        directory and the datasets in ``concat_ds`` will be reloaded with
        `preload=False`.
    overwrite : bool
        When `save_dir` is provided, controls whether to delete the old
        subdirectories that will be written to under `save_dir`. If False and
        the corresponding subdirectories already exist, a ``FileExistsError``
        will be raised.
    n_jobs : int | None
        Number of jobs for parallel execution.

    Returns
    -------
    BaseConcatDataset:
        Preprocessed dataset.
    """
    # In case of serialization, make sure directory is available before
    # preprocessing
    if save_dir is not None and not overwrite:
        _check_save_dir_empty(save_dir)

    if not isinstance(preprocessors, Iterable):
        raise ValueError(
            'preprocessors must be a list of Preprocessor objects.')
    for elem in preprocessors:
        assert hasattr(
            elem, 'apply'), ('Preprocessor object needs an `apply` method.')

    list_of_ds = Parallel(n_jobs=n_jobs)(
        delayed(_preprocess)(ds, i, preprocessors, save_dir, overwrite)
        for i, ds in enumerate(concat_ds.datasets))

    if save_dir is not None:  # Reload datasets and replace in concat_ds
        concat_ds_reloaded = load_concat_dataset(save_dir,
                                                 preload=False,
                                                 target_name=None)
        _replace_inplace(concat_ds, concat_ds_reloaded)
    else:
        if n_jobs is None or n_jobs == 1:  # joblib did not make copies, the
            # preprocessing happened in-place
            # Recompute cumulative sizes as transforms might have changed them
            concat_ds.cumulative_sizes = concat_ds.cumsum(concat_ds.datasets)
        else:  # joblib made copies
            _replace_inplace(concat_ds, BaseConcatDataset(list_of_ds))

    return concat_ds
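For reference, a minimal usage sketch of this function, assuming braindecode's ``Preprocessor`` wrapper and an existing ``concat_ds`` (older releases expose ``MNEPreproc``/``NumpyPreproc`` instead, as in Example No. 1):

from braindecode.preprocessing import Preprocessor

# `concat_ds` is assumed to be a BaseConcatDataset loaded elsewhere.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False),  # MNE method, given by name
    Preprocessor(lambda data: data * 1e6),            # numpy fn applied to the array
]
# Runs over recordings in parallel; with save_dir set, the result is
# serialized and concat_ds is reloaded with preload=False.
preprocess(concat_ds, preprocessors, save_dir='./preproc', overwrite=True,
           n_jobs=2)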
Example No. 4
def load_concat_dataset(path, preload, ids_to_load=None, target_name=None):
    """Load a stored BaseConcatDataset of BaseDatasets or WindowsDatasets from
    files.

    Parameters
    ----------
    path: str
        Path to the directory of the .fif / -epo.fif and .json files.
    preload: bool
        Whether to preload the data.
    ids_to_load: None | list(int)
        Ids of specific files to load.
    target_name: None or str
        Load specific description column as target. If not given, take saved target name.

    Returns
    -------
    concat_dataset: BaseConcatDataset of BaseDatasets or WindowsDatasets
    """
    # assume we have a single concat dataset to load
    concat_of_raws = os.path.isfile(os.path.join(path, '0-raw.fif'))
    assert concat_of_raws or target_name is None, (
        'Setting a new target is only supported for raws.')
    concat_of_epochs = os.path.isfile(os.path.join(path, '0-epo.fif'))
    paths = [path]
    # assume we have multiple concat datasets to load
    if not (concat_of_raws or concat_of_epochs):
        concat_of_raws = os.path.isfile(os.path.join(path, '0', '0-raw.fif'))
        concat_of_epochs = os.path.isfile(os.path.join(path, '0', '0-epo.fif'))
        path = os.path.join(path, '*', '')
        paths = glob(path)
        paths = sorted(paths, key=lambda p: int(p.split(os.sep)[-2]))
        if ids_to_load is not None:
            paths = [paths[i] for i in ids_to_load]
        ids_to_load = None
    # if we have neither a single nor multiple datasets, something went wrong
    assert concat_of_raws or concat_of_epochs, (
        f'Expect either raw or epo to exist in {path} or in '
        f'{os.path.join(path, "0")}')

    datasets = []
    for path in paths:
        if concat_of_raws and target_name is None:
            target_file_name = os.path.join(path, 'target_name.json')
            with open(target_file_name, 'r') as f:
                target_name = json.load(f)['target_name']

        all_signals, description = _load_signals_and_description(
            path=path, preload=preload, raws=concat_of_raws,
            ids_to_load=ids_to_load
        )
        for i_signal, signal in enumerate(all_signals):
            if concat_of_raws:
                datasets.append(
                    BaseDataset(signal, description.iloc[i_signal],
                                target_name=target_name))
            else:
                datasets.append(
                    WindowsDataset(signal, description.iloc[i_signal])
                )
    return BaseConcatDataset(datasets)
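A short reload sketch, assuming ``'./preproc'`` contains the .fif and .json files written by ``BaseConcatDataset.save()``:

# Lazily reload only the first two stored recordings.
concat_ds = load_concat_dataset('./preproc', preload=False, ids_to_load=[0, 1])
print(concat_ds.description)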
Example No. 5
def lazy_loadable_dataset(tmpdir_factory):
    """Make a dataset of fif files that can be loaded lazily.
    """
    raw = _get_raw(tmpdir_factory)
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])

    return concat_ds
Example No. 6
def concat_ds_targets():
    raws, description = fetch_data_with_moabb(
        dataset_name="BNCI2014001", subject_ids=4)
    events, _ = mne.events_from_annotations(raws[0])
    targets = events[:, -1] - 1
    ds = BaseDataset(raws[0], description.iloc[0])
    concat_ds = BaseConcatDataset([ds])
    return concat_ds, targets
Example No. 7
def _preprocess(ds, ds_index, preprocessors, save_dir=None, overwrite=False):
    """Apply preprocessor(s) to Raw or Epochs object.

    Parameters
    ----------
    ds: BaseDataset | WindowsDataset
        Dataset object to preprocess.
    ds_index : int
        Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
        is None.
    preprocessors: list(Preprocessor)
        List of preprocessors to apply to the dataset.
    save_dir : str | None
        If provided, save the preprocessed BaseDataset in the
        specified directory.
    overwrite : bool
        If True, overwrite existing file with the same name.
    """
    def _preprocess_raw_or_epochs(raw_or_epochs, preprocessors):
        for preproc in preprocessors:
            preproc.apply(raw_or_epochs)

    if hasattr(ds, 'raw'):
        _preprocess_raw_or_epochs(ds.raw, preprocessors)
    elif hasattr(ds, 'windows'):
        _preprocess_raw_or_epochs(ds.windows, preprocessors)
    else:
        raise ValueError(
            'Can only preprocess concatenation of BaseDataset or '
            'WindowsDataset, with either a `raw` or `windows` attribute.')

    # Store preprocessing keyword arguments in the dataset
    _set_preproc_kwargs(ds, preprocessors)

    if save_dir is not None:
        concat_ds = BaseConcatDataset([ds])
        concat_ds.save(save_dir, overwrite=overwrite, offset=ds_index)
    else:
        return ds
Example No. 8
def test_windows_from_events_mapping_filter(tmpdir_factory):
    raw = _get_raw(tmpdir_factory, 5 * ['T0', 'T1'])
    base_ds = BaseDataset(raw, description=pd.Series({'file_id': 1}))
    concat_ds = BaseConcatDataset([base_ds])

    windows = create_windows_from_events(
        concat_ds=concat_ds, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False, mapping={'T1': 0})
    description = windows.datasets[0].windows.metadata['target'].to_list()

    assert len(description) == 5
    np.testing.assert_array_equal(description, np.zeros(5))
    # dataset should contain only 'T1' events
    np.testing.assert_array_equal(
        (raw.time_as_index(raw.annotations.onset[1::2], use_rounding=True)),
        windows.datasets[0].windows.events[:, 0])
Example No. 9
def test_windows_from_events_n_jobs(lazy_loadable_dataset):
    longer_dataset = BaseConcatDataset([lazy_loadable_dataset.datasets[0]] * 8)
    windows = [create_windows_from_events(
        concat_ds=longer_dataset, trial_start_offset_samples=0,
        trial_stop_offset_samples=0, window_size_samples=100,
        window_stride_samples=100, drop_last_window=False, preload=True,
        n_jobs=n_jobs) for n_jobs in [1, 2]]

    assert windows[0].description.equals(windows[1].description)
    for ds1, ds2 in zip(windows[0].datasets, windows[1].datasets):
        # assert ds1.windows == ds2.windows  # Runs locally, fails in CI
        assert np.allclose(ds1.windows.get_data(), ds2.windows.get_data())
        assert pd.Series(ds1.windows.info).to_json() == \
               pd.Series(ds2.windows.info).to_json()
        assert ds1.description.equals(ds2.description)
        assert np.array_equal(ds1.y, ds2.y)
        assert np.array_equal(ds1.crop_inds, ds2.crop_inds)
Example No. 10
def test_fixed_length_windower(start_offset_samples, window_size_samples,
                               window_stride_samples, drop_last_window,
                               mapping):
    rng = np.random.RandomState(42)
    info = mne.create_info(ch_names=['0', '1'], sfreq=50, ch_types='eeg')
    data = rng.randn(2, 1000)
    raw = mne.io.RawArray(data=data, info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_ds = BaseDataset(raw, desc, target_name="age")
    concat_ds = BaseConcatDataset([base_ds])

    if window_size_samples is None:
        window_size_samples = base_ds.raw.n_times
    stop_offset_samples = data.shape[1] - start_offset_samples
    epochs_ds = create_fixed_length_windows(
        concat_ds,
        start_offset_samples=start_offset_samples,
        stop_offset_samples=stop_offset_samples,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        mapping=mapping)

    if mapping is not None:
        assert base_ds.target == 48
        assert all(epochs_ds.datasets[0].windows.metadata['target'] == 0)

    epochs_data = epochs_ds.datasets[0].windows.get_data()

    idxs = np.arange(start_offset_samples,
                     stop_offset_samples - window_size_samples + 1,
                     window_stride_samples)
    if (not drop_last_window
            and idxs[-1] != stop_offset_samples - window_size_samples):
        idxs = np.append(idxs, stop_offset_samples - window_size_samples)

    assert len(idxs) == epochs_data.shape[0], (
        'Number of epochs different than expected')
    assert window_size_samples == epochs_data.shape[2], (
        'Window size different than expected')
    for j, idx in enumerate(idxs):
        np.testing.assert_allclose(
            base_ds.raw.get_data()[:, idx:idx + window_size_samples],
            epochs_data[j, :],
            err_msg=f'Epochs different for epoch {j}')
Example No. 11
def dataset_target_time_series():
    rng = np.random.RandomState(42)
    signal_sfreq = 50
    info = mne.create_info(ch_names=['0', '1', 'target_0', 'target_1'],
                           sfreq=signal_sfreq,
                           ch_types=['eeg', 'eeg', 'misc', 'misc'])
    signal = rng.randn(2, 1000)
    targets = np.full((2, 1000), np.nan)
    targets_sfreq = 10
    targets_stride = int(signal_sfreq / targets_sfreq)
    targets[:, ::targets_stride] = rng.randn(2, int(targets.shape[1] / targets_stride))

    raw = mne.io.RawArray(np.concatenate([signal, targets]), info=info)
    desc = pd.Series({'pathological': True, 'gender': 'M', 'age': 48})
    base_dataset = BaseDataset(raw, desc, target_name=None)
    concat_ds = BaseConcatDataset([base_dataset])
    windows_dataset = create_windows_from_target_channels(
        concat_ds,
        window_size_samples=100,
    )
    return concat_ds, windows_dataset, targets, signal
Example No. 12
def exp(subject_id):
    import torch
    test_subj = np.r_[subject_id]
    print('test subj:' + str(test_subj))
    # train_subj = np.setdiff1d(np.r_[1:10], test_subj)
    train_subj = np.setdiff1d(np.r_[1, 3, 7, 8], test_subj)

    tr = []
    val = []
    for ids in train_subj:
        train_size = int(0.99 * len(splitted[ids]))
        test_size = len(splitted[ids]) - train_size
        tr_i, val_i = torch.utils.data.random_split(splitted[ids],
                                                    [train_size, test_size])
        tr.append(tr_i)
        val.append(val_i)

    train_set = torch.utils.data.ConcatDataset(tr)
    # validation is done on the held-out test subject (the per-subject 1%
    # splits collected in `val` above are not used)
    valid_set = BaseConcatDataset([splitted[ids] for ids in test_subj])

    ######################################################################
    # Create model
    # ------------
    #

    ######################################################################
    # Now we create the deep learning model! Braindecode comes with some
    # predefined convolutional neural network architectures for raw
    # time-domain EEG. Here, we use the shallow ConvNet model from `Deep
    # learning with convolutional neural networks for EEG decoding and
    # visualization <https://arxiv.org/abs/1703.05051>`__. These models are
    # pure `PyTorch <https://pytorch.org>`__ deep learning models, therefore
    # to use your own model, it just has to be a normal PyTorch
    # `nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__.
    #

    from braindecode.util import set_random_seeds
    from braindecode.models import ShallowFBCSPNet, Deep4Net

    cuda = torch.cuda.is_available()  # check if GPU is available; if so, use it
    device = 'cuda:0' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 3
    # Extract number of chans and time steps from dataset
    n_chans = train_set[0][0].shape[0]
    input_window_samples = train_set[0][0].shape[1]
    #
    # model = ShallowFBCSPNet(
    #     n_chans,
    #     n_classes,
    #     input_window_samples=input_window_samples,
    #     final_conv_length='auto',
    # )

    from mynetworks import Deep4Net_origin, ConvClfNet, FcClfNet

    model = Deep4Net(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length="auto",
    )

    #
    # embedding_net = Deep4Net_origin(4, 22, input_window_samples)
    # model = FcClfNet(embedding_net)
    # #

    print(model)

    # Send model to GPU
    if cuda:
        model.cuda()

    ######################################################################
    # Training
    # --------
    #

    ######################################################################
    # Now we train the network! EEGClassifier is a Braindecode object
    # responsible for managing the training of neural networks. It inherits
    # from skorch.NeuralNetClassifier, so the training logic is the same as in
    # `Skorch <https://skorch.readthedocs.io/en/stable/>`__.
    #

    ######################################################################
    #    **Note**: In this tutorial, we use some default parameters that we
    #    have found to work well for motor decoding, however we strongly
    #    encourage you to perform your own hyperparameter optimization using
    #    cross validation on your training data.
    #

    from skorch.callbacks import LRScheduler
    from skorch.helper import predefined_split

    from braindecode import EEGClassifier
    # These values we found to work well for the shallow network:
    lr = 0.0625 * 0.01
    weight_decay = 0

    # For deep4 they should be:
    # lr = 1 * 0.01
    # weight_decay = 0.5 * 0.001

    batch_size = 8
    n_epochs = 100

    clf = EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(
            valid_set),  # using valid_set for validation
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is already supplied
    # in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)

    ######################################################################
    # Plot Results
    # ------------
    #

    ######################################################################
    # Now we use the history stored by Skorch throughout training to plot
    # accuracy and loss curves.
    #

    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    import pandas as pd
    # Extract loss and accuracy values for plotting from history object
    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    # get percent of misclass for better visual comparison to loss
    df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
                   valid_misclass=100 - 100 * df.valid_accuracy)

    plt.style.use('seaborn-v0_8')  # the 'seaborn' style was renamed in matplotlib >= 3.6
    fig, ax1 = plt.subplots(figsize=(8, 3))
    df.loc[:, ['train_loss', 'valid_loss']].plot(ax=ax1,
                                                 style=['-', ':'],
                                                 marker='o',
                                                 color='tab:blue',
                                                 legend=False,
                                                 fontsize=14)

    ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
    ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

    df.loc[:, ['train_misclass', 'valid_misclass']].plot(ax=ax2,
                                                         style=['-', ':'],
                                                         marker='o',
                                                         color='tab:red',
                                                         legend=False)
    ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
    ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
    ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
    ax1.set_xlabel("Epoch", fontsize=14)

    # build a shared legend: solid line = train, dotted line = valid
    handles = []
    handles.append(
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle='-',
               label='Train'))
    handles.append(
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle=':',
               label='Valid'))
    plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
    plt.tight_layout()

    # plt.show()

    return df
Example No. 13
def get_sleep_stages(windows_dataset):
    sleep_stages = [x[1] for x in BaseConcatDataset(windows_dataset.datasets)]
    return np.array(sleep_stages)
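A usage sketch, under the assumption that each item of ``windows_dataset`` is an ``(X, y, ...)`` tuple with the sleep-stage label at index 1, which is what the helper above indexes:

stages = get_sleep_stages(windows_dataset)    # windows_dataset created elsewhere
print(np.unique(stages, return_counts=True))  # per-stage label counts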
Example No. 14
def exp(subject_id):
    cuda = torch.cuda.is_available()  # check if GPU is available; if so, use it
    device = 'cuda:1' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 10  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=seed, cuda=cuda)

    test_subj = np.r_[subject_id]

    print('test subj:' + str(test_subj))
    train_subj = np.setdiff1d(np.r_[1:10], test_subj)

    tr = []
    val = []

    # split off 10% of each training subject's data to build the validation set
    for ids in train_subj:
        train_size = int(0.9 * len(splitted[ids]))
        test_size = len(splitted[ids]) - train_size
        tr_i, val_i = torch.utils.data.random_split(splitted[ids],
                                                    [train_size, test_size])
        tr.append(tr_i)
        val.append(val_i)

    train_set = torch.utils.data.ConcatDataset(tr)
    valid_set = torch.utils.data.ConcatDataset(val)
    test_set = BaseConcatDataset([splitted[ids] for ids in test_subj])

    # model = Deep4Net(
    #     n_chans,
    #     n_classes,
    #     input_window_samples=input_window_samples,
    #     final_conv_length="auto",
    # )

    crop_size = 1125
    embedding_net = EEGNet_v2_old(n_classes, n_chans, crop_size)
    model = FcClfNet(embedding_net)

    print(model)

    epochs = 100

    batch_size = 64

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=batch_size,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Send model to GPU
    if cuda:
        model.to(device)

    from torch.optim import lr_scheduler
    import torch.optim as optim

    import argparse
    parser = argparse.ArgumentParser(
        description='cross subject domain adaptation')
    parser.add_argument('--batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='input batch size for training (default: 50)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=50,
                        metavar='N',
                        help='input batch size for testing (default: 50)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    args.gpuidx = 1
    args.seed = 0
    args.use_tensorboard = False
    args.save_model = False

    optimizer = optim.AdamW(model.parameters(),
                            lr=0.001,
                            weight_decay=0.5 * 0.001)
    # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50)
    #
    # #test lr
    # lr = []
    # for i in range(200):
    #     scheduler.step()
    #     lr.append(scheduler.get_lr())
    #
    # import matplotlib.pyplot as plt
    # plt.plot(lr)

    import pandas as pd
    results_columns = [
        'val_loss', 'test_loss', 'val_accuracy', 'test_accuracy'
    ]
    df = pd.DataFrame(columns=results_columns)

    for epochidx in range(1, epochs + 1):
        print(epochidx)
        train(10, model, device, train_loader, optimizer, scheduler, cuda,
              device)
        val_loss, val_score = eval(model, device, valid_loader)
        test_loss, test_score = eval(model, device, test_loader)
        results = {
            'val_loss': val_loss,
            'test_loss': test_loss,
            'val_accuracy': val_score,
            'test_accuracy': test_score
        }
        df = pd.concat([df, pd.DataFrame([results])], ignore_index=True)
        print(results)

    return df
Example No. 15
def create_from_X_y(X,
                    y,
                    drop_last_window,
                    sfreq=None,
                    ch_names=None,
                    window_size_samples=None,
                    window_stride_samples=None):
    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
    decoding with skorch and braindecode, where X is a list of pre-cut trials
    and y are corresponding targets.

    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    sfreq: float
        common sampling frequency of all trials
    ch_names: array-like
        channel names of the trials
    drop_last_window: bool
        whether to drop the last (potentially overlapping) window when the
        windows do not equally divide the continuous signal
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows

    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    n_samples_per_x = []
    base_datasets = []
    if sfreq is None:
        sfreq = 100
        log.info("No sampling frequency given, set to 100 Hz.")
    if ch_names is None:
        ch_names = [str(i) for i in range(X.shape[1])]
        log.info(f"No channel names given, set to 0-{X.shape[1]}).")

    for x, target in zip(X, y):
        n_samples_per_x.append(x.shape[1])
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(raw,
                                   pd.Series({"target": target}),
                                   target_name="target")
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        if len(np.unique(n_samples_per_x)) != 1:
            raise ValueError("if 'window_size_samples' and "
                             "'window_stride_samples' are None, "
                             "all trials have to have the same length")
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]
    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window)
    return windows_datasets
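A self-contained sketch wrapping synthetic trials; the shapes, sampling frequency, and channel names are illustrative only:

import numpy as np

X = np.random.randn(20, 2, 500)        # 20 trials, 2 channels, 500 samples
y = np.random.randint(0, 2, size=20)   # binary targets
windows_ds = create_from_X_y(X, y, drop_last_window=False, sfreq=100.,
                             ch_names=['C3', 'C4'])
print(len(windows_ds))  # one window per trial when no window size is given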
Example No. 16
def exp(subject_id):
    test_subj = np.r_[subject_id]
    print('test subj:' + str(test_subj))
    train_subj = np.setdiff1d(np.r_[1:11], test_subj)

    train_set = BaseConcatDataset([splitted[ids] for ids in train_subj])
    valid_set = BaseConcatDataset([splitted[ids] for ids in test_subj])

    # #
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length=30,
    )

    # #
    # embedding_net = Deep4Net_origin(4, 22, input_window_samples)
    # model = FcClfNet(embedding_net)

    print(model)

    # Send model to GPU
    if cuda:
        model.cuda()
    from braindecode.models.util import to_dense_prediction_model
    to_dense_prediction_model(model)
    ######################################################################
    # Training
    # --------
    #

    ######################################################################
    # In contrast to trialwise decoding, we now have to supply
    # ``cropped=True`` to the EEGClassifier, ``CroppedLoss`` as the
    # criterion, and ``criterion__loss_function`` as the loss function
    # applied to the averaged predictions.
    #

    ######################################################################
    # .. note::
    #    In this tutorial, we use some default parameters that we
    #    have found to work well for motor decoding, however we strongly
    #    encourage you to perform your own hyperparameter optimization using
    #    cross validation on your training data.
    #

    from skorch.callbacks import LRScheduler
    from skorch.helper import predefined_split

    from braindecode import EEGClassifier
    from braindecode.training.losses import CroppedLoss
    from braindecode.training.scoring import trial_preds_from_window_preds

    # These values we found to work well for the shallow network:
    lr = 0.0625 * 0.01
    weight_decay = 0

    # # For deep4 they should be:
    # lr = 1 * 0.01
    # weight_decay = 0.5 * 0.001

    batch_size = 400
    n_epochs = 100

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is already supplied
    # in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)

    ######################################################################
    # Plot Results
    # ------------
    #

    ######################################################################
    # This is again the same code as in trialwise decoding.
    #
    # .. note::
    #     Note that the classification error and loss drop further than
    #     in the trialwise decoding tutorial.
    #

    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    import pandas as pd
    # Extract loss and accuracy values for plotting from history object
    results_columns = [
        'train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy'
    ]
    df = pd.DataFrame(clf.history[:, results_columns],
                      columns=results_columns,
                      index=clf.history[:, 'epoch'])

    # get percent of misclass for better visual comparison to loss
    df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
                   valid_misclass=100 - 100 * df.valid_accuracy)

    plt.style.use('seaborn-v0_8')  # the 'seaborn' style was renamed in matplotlib >= 3.6
    fig, ax1 = plt.subplots(figsize=(8, 3))
    df.loc[:, ['train_loss', 'valid_loss']].plot(ax=ax1,
                                                 style=['-', ':'],
                                                 marker='o',
                                                 color='tab:blue',
                                                 legend=False,
                                                 fontsize=14)

    ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
    ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

    df.loc[:, ['train_misclass', 'valid_misclass']].plot(ax=ax2,
                                                         style=['-', ':'],
                                                         marker='o',
                                                         color='tab:red',
                                                         legend=False)
    ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
    ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
    ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
    ax1.set_xlabel("Epoch", fontsize=14)

    # build a shared legend: solid line = train, dotted line = valid
    handles = []
    handles.append(
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle='-',
               label='Train'))
    handles.append(
        Line2D([0], [0],
               color='black',
               linewidth=1,
               linestyle=':',
               label='Valid'))
    plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
    plt.tight_layout()

    return df