def test_relabel():
    """
    relabel() yields the expected boolean vector for each class of a
    three-class label matrix.
    """
    nb_classes = 3
    inputs = np.array([
        [True, False, False],
        [False, True, False],
        [True, False, True],
    ])

    # One expected output per class index, listed in class order.
    # Each vector equals the matching column of `inputs`.
    expected_per_class = (
        np.array([True, False, True]),
        np.array([False, True, False]),
        np.array([False, False, True]),
    )

    for class_index, expected in enumerate(expected_per_class):
        actual = relabel(inputs, class_index, nb_classes)
        assert np.array_equal(actual, expected)
def test_relabel_binary():
    """
    relabel() returns binary-classification labels unchanged.
    """
    nb_classes = 2
    labels = np.array([True, False, False])

    # With two classes, relabelling for class 0 is expected to hand back
    # an array equal to the input.
    relabelled = relabel(labels, 0, nb_classes)

    assert np.array_equal(relabelled, labels)