コード例 #1
0
def test_relative_positioning_sampler(windows_ds, same_rec_neg):
    """Sampled pairs must respect tau_pos/tau_neg and recording membership.

    Positive pairs (y == 1) must start within ``tau_pos`` samples of each
    other; negative pairs (y == 0) must be at least ``tau_neg`` apart when
    drawn from the same recording, or come from different recordings when
    ``same_rec_neg`` is False. Classes should be roughly balanced.
    """
    tau_pos, tau_neg = 2000, 3000
    n_examples = 100
    sampler = RelativePositioningSampler(
        windows_ds.get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
        n_examples=n_examples, tau_max=None, same_rec_neg=same_rec_neg,
        random_state=33)

    pairs = list(sampler)
    pairs_df = pd.DataFrame(pairs, columns=['win_ind1', 'win_ind2', 'y'])

    def _onset_gap(row):
        # Absolute distance between the start-sample values of both windows.
        first = windows_ds[int(row['win_ind1'])][2][1]
        second = windows_ds[int(row['win_ind2'])][2][1]
        return abs(first - second)

    def _same_recording(row):
        # True when both windows belong to the same underlying dataset.
        return (find_dataset_ind(windows_ds, int(row['win_ind1'])) ==
                find_dataset_ind(windows_ds, int(row['win_ind2'])))

    pairs_df['diff'] = pairs_df.apply(_onset_gap, axis=1)
    pairs_df['same_rec'] = pairs_df.apply(_same_recording, axis=1)

    assert len(pairs) == n_examples == len(sampler)
    pos_mask = pairs_df['y'] == 1
    neg_mask = pairs_df['y'] == 0
    assert (pairs_df.loc[pos_mask, 'diff'] <= tau_pos).all()
    if same_rec_neg:
        assert (pairs_df.loc[neg_mask, 'diff'] >= tau_neg).all()
        assert (pairs_df['same_rec'] == same_rec_neg).all()
    else:
        # Negatives must span recordings; positives must stay within one.
        assert not pairs_df.loc[neg_mask, 'same_rec'].any()
        assert pairs_df.loc[pos_mask, 'same_rec'].all()
    # Class counts should differ by fewer than 20 out of 100 pairs.
    assert abs(np.diff(pairs_df['y'].value_counts())) < 20
コード例 #2
0
def test_relative_positioning_sampler_presample(windows_ds):
    """Presampling fixes the pairs: repeated iteration yields the same ones.

    After ``presample()`` the sampler exposes an ``examples`` attribute of
    length ``n_examples``, and every subsequent iteration replays exactly
    those precomputed pairs.
    """
    tau_pos, tau_neg = 2000, 3000
    n_examples = 100
    sampler = RelativePositioningSampler(
        windows_ds.get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
        n_examples=n_examples, tau_max=None, same_rec_neg=True,
        random_state=33)

    sampler.presample()
    assert hasattr(sampler, 'examples')
    assert len(sampler.examples) == n_examples

    first_pass = list(sampler)
    second_pass = list(sampler)
    # Both passes must reproduce the presampled pairs exactly.
    assert np.array_equal(sampler.examples, first_pass)
    assert np.array_equal(sampler.examples, second_pass)
コード例 #3
0
#
# The samplers also control the number of pairs to be sampled (defined with
# `n_examples`). This number can be large to help regularize the pretext task
# training, for instance 2,000 pairs per recording as in [1]_. Here, we use a
# lower number of 250 pairs per recording to reduce training time.
#

from braindecode.samplers.ssl import RelativePositioningSampler

# Positive pairs must start within 60 s of each other; negative pairs must be
# at least 15 min apart. Both thresholds are converted from seconds to samples
# using the sampling frequency.
tau_pos, tau_neg = int(sfreq * 60), int(sfreq * 15 * 60)
# 250 pairs per recording in each split (vs. 2,000 in [1]_, to keep training
# time short — see the section comment above).
n_examples_train = 250 * len(splitted['train'].datasets)
n_examples_valid = 250 * len(splitted['valid'].datasets)
n_examples_test = 250 * len(splitted['test'].datasets)

# The training sampler is NOT presampled, so it can draw fresh pairs on each
# pass; validation and test samplers are presampled so that evaluation always
# sees the same fixed set of pairs. NOTE(review): the chained assignment
# assumes presample() returns the sampler itself — confirm against the
# braindecode API.
train_sampler = RelativePositioningSampler(
    splitted['train'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
    n_examples=n_examples_train, same_rec_neg=True, random_state=random_state)
valid_sampler = RelativePositioningSampler(
    splitted['valid'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
    n_examples=n_examples_valid, same_rec_neg=True,
    random_state=random_state).presample()
test_sampler = RelativePositioningSampler(
    splitted['test'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
    n_examples=n_examples_test, same_rec_neg=True,
    random_state=random_state).presample()


######################################################################
# Creating the model
# ------------------
#