def test_cropped_trial_epoch_scoring_none_x_test():
    """Check that ``on_epoch_end`` returns ``None`` when both datasets are None.

    The scorer's cached predictions, targets and window indices are assigned
    by hand before the epoch-end hook is called with a mock net.
    """
    # Four hand-written prediction rows; split below into two batches of two.
    window_preds = np.array(
        [
            [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
            [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],
            [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
        ]
    )
    targets = [torch.tensor([0, 0]), torch.tensor([1, 1])]

    # One triple per batch:
    # (i_window_in_trials, placeholder that won't be used, i_window_stops).
    batch_window_inds = [
        (torch.tensor([0, 0]), [None], torch.tensor([4, 4]))
        for _ in range(2)
    ]

    scorer = CroppedTrialEpochScoring("accuracy")
    scorer.initialize()
    # Assigned after initialize(); presumably initialize() resets these
    # caches -- confirm against the callback implementation.
    scorer.y_preds_ = [
        to_tensor(batch, device="cpu")
        for batch in (window_preds[:2], window_preds[2:])
    ]
    scorer.y_trues_ = targets
    scorer.window_inds_ = batch_window_inds

    net = MockSkorchNet()
    net.callbacks_ = [("", scorer)]

    dataset_train = None
    dataset_valid = None
    result = scorer.on_epoch_end(net, dataset_train, dataset_valid)

    assert result is None
def test_cropped_trial_epoch_scoring():
    """Check the accuracy written to the net history for two hand-made cases.

    For each case the scorer's cached predictions, targets and window indices
    are assigned by hand, ``on_epoch_end`` is called with a mock net, and the
    value stored under ``history[0]["accuracy"]`` is compared with the
    expected accuracy.
    """
    dataset_train = None

    # Each case is (predictions, per-batch targets, expected accuracy).
    cases = [
        (
            # Expected per-trial classification results: [1, 0, 0, 0]
            np.array(
                [
                    [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],  # trial 0
                    [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],  # trial 1
                    [[1.0, 1.0, 1.0, 0.2], [0.0, 0.0, 0.0, 0.8]],  # trial 2
                    [[0.9, 0.8, 0.9, 1.0], [0.1, 0.2, 0.1, 0.0]],  # trial 3
                ]
            ),
            [torch.tensor([0, 0]), torch.tensor([1, 1])],
            0.25,
        ),
        (
            # Expected per-trial classification results: [1, 1, 1, 0]
            np.array(
                [
                    [[0.2, 0.1, 0.1, 0.1], [0.8, 0.9, 0.9, 0.9]],
                    [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
                    [[0.0, 0.0, 0.0, 0.2], [1.0, 1.0, 1.0, 0.8]],
                    [[0.9, 0.8, 0.9, 1.0], [0.1, 0.2, 0.1, 0.0]],
                ]
            ),
            [torch.tensor([1, 1]), torch.tensor([1, 1])],
            0.75,
        ),
    ]

    # One triple per batch, shared by every case:
    # (i_window_in_trials, placeholder that won't be used, i_window_stops).
    batch_window_inds = [
        (torch.tensor([0, 0]), [None], torch.tensor([4, 4]))
        for _ in range(2)
    ]

    for window_preds, targets, expected_accuracy in cases:
        dataset_valid = create_from_X_y(
            np.zeros((4, 1, 10)),
            np.concatenate(targets),
            window_size_samples=10,
            window_stride_samples=4,
            drop_last_window=False,
        )

        net = MockSkorchNet()
        scorer = CroppedTrialEpochScoring("accuracy", on_train=False)
        net.callbacks = [("", scorer)]

        scorer.initialize()
        # Assigned after initialize(); presumably initialize() resets these
        # caches -- confirm against the callback implementation.
        scorer.y_preds_ = [
            to_tensor(batch, device="cpu")
            for batch in (window_preds[:2], window_preds[2:])
        ]
        scorer.y_trues_ = targets
        scorer.window_inds_ = batch_window_inds

        scorer.on_epoch_end(net, dataset_train, dataset_valid)

        np.testing.assert_almost_equal(
            net.history[0]["accuracy"], expected_accuracy
        )