예제 #1
0
def test_remaining_obs_put_in_last_slice():
    """This was a bug found in testing. This particular dataset puts 9
    counts"""
    y = np.arange(100)

    slice_indicator, counts = slice_y(y, n_slices=11)

    np.testing.assert_array_equal(np.hstack((np.repeat(9, 10), 10)), counts)

    expected_slices = np.hstack((np.repeat(np.arange(10),
                                           9), np.repeat(10, 10)))
    np.testing.assert_array_equal(expected_slices, slice_indicator)
예제 #2
0
def test_uneven_slices_within_y():
    """Make sure the slice function actually groups common y values together.
    """
    y = np.hstack((np.repeat(1, 13), np.repeat(2, 5), np.repeat(3, 25)))

    slice_indicator, counts = slice_y(y, n_slices=3)

    np.testing.assert_array_equal(counts, np.array([13, 5, 25]))

    expected_indicator = np.hstack(
        (np.repeat(0, 13), np.repeat(1, 5), np.repeat(2, 25)))
    np.testing.assert_array_equal(slice_indicator, expected_indicator)
예제 #3
0
def test_uneven_slices():
    """Too match DR the slice function spills over to 6 slices if it cannot
    fit it into 5..."""
    shuffle = module_rng.permutation(68)
    y = np.arange(68)[shuffle]

    slice_indicator, counts = slice_y(y, n_slices=5)

    np.testing.assert_array_equal(counts, np.array([13, 13, 13, 13, 13, 3]))

    np.testing.assert_array_equal(
        slice_indicator,
        np.hstack((np.repeat(np.arange(5), 13), np.array([5, 5, 5]))))
예제 #4
0
def test_slice_y_simple():
    """Test a simple a target with 100 unique values and even slices."""
    shuffle = module_rng.permutation(100)
    y = np.arange(100)[shuffle]

    slice_indicator, counts = slice_y(y, n_slices=10)

    # all slices should have 10 counts
    np.testing.assert_array_equal(counts, np.repeat(10, 10))

    # continguous indices counting up from 10
    np.testing.assert_array_equal(slice_indicator,
                                  np.repeat(np.arange(10), 10))
예제 #5
0
def test_slice_raises_error_when_y_has_one_value():
    """Error if y only has one value."""
    y = np.ones(100)

    with pytest.raises(ValueError):
        slice_y(y)
예제 #6
0
def test_slice_match_athletes():
    """Test that we match the slices used in the DR package on the
    athlete's dataset.
    """
    X, y = load_athletes()

    slice_indicator, counts = slice_y(y, n_slices=11)

    # counts per slice match DR package
    expected_counts = np.array([18, 18, 18, 18, 18, 19, 18, 19, 23, 18, 15])
    np.testing.assert_array_equal(counts, expected_counts)

    # slice indicators match DR package
    expected_indicator = np.array([
        6,
        5,
        4,
        4,
        2,
        3,
        5,
        1,
        3,
        3,
        8,
        6,
        1,
        7,
        5,
        4,
        3,
        4,
        6,
        4,
        6,
        6,
        4,
        3,
        7,
        6,
        5,
        4,
        1,
        2,
        2,
        5,
        5,
        6,
        6,
        2,
        1,
        3,
        2,
        4,
        4,
        2,
        3,
        4,
        2,
        5,
        4,
        5,
        3,
        1,
        3,
        4,
        3,
        5,
        5,
        3,
        3,
        3,
        4,
        2,
        5,
        5,
        6,
        5,
        4,
        2,
        2,
        2,
        5,
        8,
        6,
        6,
        7,
        5,
        8,
        1,
        3,
        1,
        3,
        1,
        3,
        1,
        2,
        3,
        2,
        1,
        2,
        2,
        2,
        4,
        4,
        1,
        2,
        1,
        2,
        1,
        1,
        1,
        1,
        1,
        5,
        8,
        9,
        10,
        9,
        8,
        8,
        9,
        9,
        10,
        7,
        9,
        11,
        9,
        9,
        9,
        11,
        9,
        10,
        10,
        1,
        10,
        10,
        10,
        10,
        11,
        10,
        9,
        11,
        8,
        10,
        10,
        11,
        11,
        11,
        9,
        10,
        8,
        9,
        8,
        4,
        7,
        7,
        8,
        11,
        10,
        9,
        6,
        7,
        8,
        6,
        5,
        4,
        8,
        3,
        7,
        7,
        9,
        9,
        11,
        9,
        9,
        11,
        7,
        9,
        8,
        7,
        6,
        6,
        6,
        7,
        7,
        6,
        6,
        5,
        11,
        11,
        11,
        10,
        7,
        8,
        9,
        7,
        9,
        8,
        8,
        9,
        7,
        11,
        9,
        9,
        11,
        8,
        10,
        10,
        10,
        7,
        10,
        8,
        7,
        6,
        8,
    ])
    np.testing.assert_array_equal(slice_indicator,
                                  np.sort(expected_indicator) - 1)