def test_trialwise_predict_and_predict_proba():
    """Trialwise decoding: ``predict`` is the argmax, ``predict_proba`` is a pass-through.

    ``MockModule`` is presumably a stub that emits the array it is given as
    the network output (TODO confirm), so both expectations are derived
    directly from that array.
    """
    # Rows are examples; argmax over axis 1 yields the predicted label.
    class_scores = np.array(
        [
            [0.125, 0.875],
            [1., 0.],
            [0.8, 0.2],
            [0.9, 0.1],
        ]
    )
    classifier = EEGClassifier(
        MockModule(class_scores),
        optimizer=optim.Adam,
        batch_size=32,
    )
    classifier.initialize()

    expected_labels = class_scores.argmax(1)
    np.testing.assert_array_equal(
        expected_labels, classifier.predict(MockDataset())
    )
    np.testing.assert_array_equal(
        class_scores, classifier.predict_proba(MockDataset())
    )


def test_cropped_predict_and_predict_proba_not_aggregate_predictions():
    """Cropped decoding with ``aggregate_predictions=False`` keeps per-crop outputs.

    With aggregation disabled, ``predict_proba`` is expected to return the
    module output unchanged, and ``predict`` to take the argmax over axis 1
    independently for every position along the last axis.
    """
    # Presumably (example, class, crop); the last axis is the one the
    # aggregating classifier averages over -- TODO confirm.
    crop_scores = np.array(
        [
            [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
            [[1.0, 1.0, 1.0, 0.2], [0.0, 0.0, 0.0, 0.8]],
            [[0.9, 0.8, 0.9, 1.0], [0.1, 0.2, 0.1, 0.0]],
        ]
    )
    classifier_kwargs = dict(
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=nll_loss,
        optimizer=optim.Adam,
        batch_size=32,
        aggregate_predictions=False,
    )
    classifier = EEGClassifier(MockModule(crop_scores), **classifier_kwargs)
    classifier.initialize()

    per_crop_labels = crop_scores.argmax(1)
    np.testing.assert_array_equal(
        per_crop_labels, classifier.predict(MockDataset())
    )
    np.testing.assert_array_equal(
        crop_scores, classifier.predict_proba(MockDataset())
    )


def test_cropped_predict_and_predict_proba():
    """Cropped decoding with default aggregation yields one output per trial.

    The classifier is expected to average the module output over the last
    (crop) axis: ``predict_proba`` returns that mean and ``predict`` returns
    its argmax over axis 1.
    """
    # Presumably (example, class, crop); averaging over axis -1 collapses
    # the crops into a single per-trial score -- TODO confirm axis meaning.
    crop_scores = np.array(
        [
            [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
            [[1.0, 1.0, 1.0, 0.2], [0.0, 0.0, 0.0, 0.8]],
            [[0.9, 0.8, 0.9, 1.0], [0.1, 0.2, 0.1, 0.0]],
        ]
    )
    classifier_kwargs = dict(
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=nll_loss,
        optimizer=optim.Adam,
        batch_size=32,
    )
    classifier = EEGClassifier(MockModule(crop_scores), **classifier_kwargs)
    classifier.initialize()

    trial_scores = crop_scores.mean(-1)
    # One label per trial, taken from the crop-averaged scores.
    np.testing.assert_array_equal(
        trial_scores.argmax(1), classifier.predict(MockDataset())
    )
    # One score vector per trial: the average over all crops.
    np.testing.assert_array_equal(
        trial_scores, classifier.predict_proba(MockDataset())
    )