Example #1
0
def test_tree_one_split(veterans):
    """Check a depth-1 ``SurvivalTree`` fitted on a single feature.

    The tree is grown on the ``Karnofsky_score`` column only. The test then
    compares it against several references:

    * Its node arrays (``capacity``, ``feature``, ``n_node_samples``,
      ``threshold``) against the per-node table produced by
      ``LogrankTreeBuilder``.
    * Its unique event times against a hard-coded list.
    * Its predictions for one query point on each side of the root threshold
      against estimates computed directly on the training samples of that side:
      Nelson-Aalen for the cumulative hazard, Kaplan-Meier for survival, plus
      hard-coded ``predict`` outputs.
    """
    features, outcome = veterans
    # Keep a 2-D (n_samples, 1) design matrix with the single feature.
    karnofsky = features.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]

    tree = SurvivalTree(max_depth=1)
    tree.fit(karnofsky, outcome)

    reference = LogrankTreeBuilder(max_depth=1).build(karnofsky, outcome)

    # Node-level structure must agree with the reference builder, row by row.
    assert tree.tree_.capacity == reference.shape[0]
    node_checks = (
        ("feature", assert_array_equal),
        ("n_node_samples", assert_array_equal),
        ("threshold", assert_array_almost_equal),
    )
    for column, check in node_checks:
        check(getattr(tree.tree_, column), reference.loc[:, column].values)

    expected_event_times = numpy.array(
        [
            1, 2, 3, 4, 7, 8, 10, 11, 12, 13, 15, 16, 18, 19, 20, 21, 22, 24,
            25, 27, 29, 30, 31, 33, 35, 36, 42, 43, 44, 45, 48, 49, 51, 52, 53,
            54, 56, 59, 61, 63, 72, 73, 80, 82, 84, 87, 90, 92, 95, 99, 100,
            103, 105, 110, 111, 112, 117, 118, 122, 126, 132, 133, 139, 140,
            143, 144, 151, 153, 156, 162, 164, 177, 186, 200, 201, 216, 228,
            231, 242, 250, 260, 278, 283, 287, 314, 340, 357, 378, 384, 389,
            392, 411, 467, 553, 587, 991, 999,
        ],
        dtype=float,
    )
    assert_array_equal(tree.event_times_, expected_event_times)

    # Partition the training samples at the root threshold; samples with a
    # value <= threshold are treated as the left side, the rest as the right.
    split_value = reference.loc[0, "threshold"]
    left_mask = karnofsky[:, 0] <= split_value
    outcome_by_side = (outcome[left_mask], outcome[~left_mask])

    # One query point 10 units below and one 10 units above the split.
    queries = numpy.array([[split_value - 10], [split_value + 10]])

    hazard_curves = tree.predict_cumulative_hazard_function(
        queries, return_array=True
    )
    for predicted, subset in zip(hazard_curves, outcome_by_side):
        _, expected = nelson_aalen_estimator(
            subset["Status"], subset["Survival_in_days"]
        )
        assert_curve_almost_equal(predicted, expected)

    assert_array_almost_equal(
        tree.predict(queries), numpy.array([196.55878, 86.14939])
    )

    survival_curves = tree.predict_survival_function(queries, return_array=True)
    for predicted, subset in zip(survival_curves, outcome_by_side):
        _, expected = kaplan_meier_estimator(
            subset["Status"], subset["Survival_in_days"]
        )
        assert_curve_almost_equal(predicted, expected)
Example #2
0
def test_nelson_aalen_all_uncensored():
    """Nelson-Aalen estimate when every observation is an event.

    The expected cumulative hazard grows at each unique time by
    (events at t) / (subjects still at risk at t). For example:

    * t=1: 1/16.
    * t=2: adds 2/15.
    * t=3: adds 2/13.
    * The last step, t=23: adds 1/1.
    """
    time = [1, 2, 2, 3, 7, 6, 5, 5, 3, 9, 11, 23, 17, 13, 6, 13]
    # Boolean event indicator: no observation is censored.
    event = numpy.full(len(time), True)

    unique_times, cum_hazard = nelson_aalen_estimator(event, time)

    expected_times = numpy.array([1, 2, 3, 5, 6, 7, 9, 11, 13, 17, 23])
    expected_hazard = numpy.array([
        0.0625, 0.195833333333333, 0.349679487179487, 0.531497668997669,
        0.753719891219891, 0.896577034077034, 1.0632437007437,
        1.2632437007437, 1.7632437007437, 2.2632437007437, 3.2632437007437,
    ])

    assert_array_equal(unique_times, expected_times)
    assert_array_almost_equal(cum_hazard, expected_hazard)
Example #3
0
def test_nelson_aalen_first_censored():
    """Nelson-Aalen estimate when only the earliest observation is censored.

    The censored time t=1 is still reported as a time point, with zero
    cumulative hazard. Every later step adds
    (events at t) / (subjects still at risk at t), starting with 2/15 at t=2.
    """
    time = [1, 2, 2, 3, 7, 6, 5, 5, 3, 9, 11, 13, 17, 13, 6, 23]
    # Every observation is an event except the first one (t=1).
    event = numpy.arange(len(time)) != 0

    unique_times, cum_hazard = nelson_aalen_estimator(event, time)

    expected_times = numpy.array([1, 2, 3, 5, 6, 7, 9, 11, 13, 17, 23])
    expected_hazard = numpy.array([
        0, 0.133333333333333, 0.287179487179487, 0.468997668997669,
        0.691219891219891, 0.834077034077034, 1.0007437007437,
        1.2007437007437, 1.7007437007437, 2.2007437007437, 3.2007437007437,
    ])

    assert_array_equal(unique_times, expected_times)
    assert_array_almost_equal(cum_hazard, expected_hazard)
Example #4
0
    def test_whas500(make_whas500):
        """Nelson-Aalen estimate on unscaled WHAS500 data matches stored values.

        The outcome record's ``'fstat'`` field is passed as the event indicator
        and ``'lenfol'`` as the time.
        """
        # NOTE(review): this function is indented like a method but takes no
        # ``self``; presumably it carries ``@staticmethod`` in the enclosing
        # class (decorator not visible here) -- confirm.
        expected_times = numpy.array([
            1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 16, 17, 18, 19, 20, 22, 26, 31,
            32, 33, 34, 37, 42, 46, 49, 52, 53, 55, 57, 60, 61, 62, 64, 69, 76,
            81, 83, 88, 91, 93, 95, 97, 100, 101, 108, 109, 113, 116, 117, 118,
            129, 132, 134, 135, 137, 140, 143, 145, 146, 151, 166, 169, 187,
            192, 197, 200, 226, 233, 235, 259, 269, 274, 287, 289, 295, 297,
            312, 313, 321, 328, 343, 345, 354, 358, 359, 363, 368, 371, 373,
            376, 382, 385, 386, 390, 392, 397, 398, 399, 400, 403, 405, 406,
            407, 408, 411, 412, 416, 418, 419, 421, 422, 424, 426, 427, 433,
            437, 440, 442, 445, 446, 449, 450, 451, 452, 457, 458, 459, 465,
            466, 467, 473, 475, 478, 479, 480, 486, 497, 506, 507, 510, 511,
            516, 519, 521, 522, 523, 524, 529, 530, 532, 535, 537, 542, 544,
            550, 551, 552, 554, 559, 562, 568, 570, 573, 578, 587, 589, 606,
            609, 612, 614, 626, 631, 632, 644, 646, 649, 654, 659, 662, 670,
            673, 675, 704, 714, 718, 725, 849, 865, 903, 905, 920, 936, 953,
            1048, 1054, 1065, 1096, 1098, 1102, 1103, 1105, 1106, 1107, 1108,
            1109, 1114, 1117, 1121, 1123, 1125, 1126, 1136, 1140, 1150, 1151,
            1152, 1157, 1159, 1160, 1161, 1162, 1163, 1165, 1169, 1170, 1174,
            1178, 1182, 1187, 1189, 1190, 1191, 1196, 1199, 1200, 1203, 1207,
            1211, 1217, 1223, 1224, 1231, 1232, 1233, 1234, 1235, 1244, 1245,
            1248, 1251, 1253, 1256, 1257, 1262, 1265, 1266, 1272, 1273, 1274,
            1277, 1279, 1280, 1290, 1295, 1298, 1302, 1308, 1314, 1317, 1319,
            1320, 1325, 1329, 1332, 1333, 1336, 1338, 1346, 1347, 1353, 1359,
            1363, 1365, 1366, 1374, 1377, 1378, 1381, 1384, 1385, 1388, 1390,
            1400, 1408, 1409, 1420, 1430, 1433, 1438, 1444, 1449, 1451, 1454,
            1456, 1458, 1496, 1506, 1527, 1536, 1548, 1553, 1576, 1577, 1579,
            1624, 1627, 1671, 1831, 1836, 1847, 1854, 1858, 1863, 1880, 1883,
            1885, 1887, 1889, 1893, 1899, 1904, 1914, 1919, 1920, 1923, 1926,
            1931, 1933, 1934, 1936, 1939, 1940, 1941, 1942, 1954, 1955, 1964,
            1969, 1976, 1977, 1979, 1993, 1994, 2006, 2009, 2025, 2032, 2048,
            2057, 2061, 2064, 2065, 2066, 2083, 2084, 2086, 2100, 2108, 2113,
            2114, 2118, 2122, 2123, 2125, 2126, 2131, 2132, 2139, 2145, 2146,
            2151, 2152, 2156, 2160, 2166, 2168, 2172, 2173, 2175, 2178, 2190,
            2192, 2350, 2353, 2358
        ])

        expected_hazard = numpy.array([
            0.016, 0.032260162601626, 0.038458509709064, 0.0426165138670682,
            0.0467918792115359, 0.0572740595050369, 0.0699859239118166,
            0.0764236921521598, 0.0850630010074515, 0.089420299482397,
            0.0916084832898368, 0.0959944482021175, 0.102602377717536,
            0.109254262418201, 0.113718548132487, 0.118202853065222,
            0.120455105317475, 0.12271244166059, 0.127237328538418,
            0.1340555103566, 0.13634383987605, 0.138637417857702,
            0.140936268432414, 0.143240415897852, 0.145549884720023,
            0.147864699534838, 0.150184885149687, 0.152510466545036,
            0.157172471207041, 0.159514391581748, 0.161861809422123,
            0.164214750598594, 0.16893173173067, 0.173671068223561,
            0.176052020604513, 0.178438655449382, 0.180830999946989,
            0.183229081481762, 0.185632927635608, 0.188042566189824,
            0.190458025127023, 0.192879332633076, 0.195306517099095,
            0.197739607123426, 0.20017863151367, 0.202623619288731,
            0.205074599680888, 0.20753160213789, 0.209994656325083,
            0.212463792127552, 0.214939039652304, 0.217420429230468,
            0.219907991419523, 0.222401757005558, 0.224901757005558,
            0.229914288333878, 0.232433180021536, 0.234958432546788,
            0.237490078116409, 0.240028149182399, 0.242572678444485,
            0.247674719260812, 0.252802924389017, 0.255380243976646,
            0.257964223304811, 0.26055489687994, 0.263152299477343,
            0.26575646614401, 0.268367432201451, 0.273603034295692,
            0.27623461324306, 0.278873135670501, 0.281518638316003,
            0.284171158209903, 0.286830732677988, 0.292164066011321,
            0.294845031158774, 0.297533203201785, 0.300228620991542,
            0.302931323694245, 0.305641350794516, 0.308358742098864,
            0.311083537739191, 0.313815778176349, 0.319295230231144,
            0.322050051167783, 0.322050051167783, 0.322050051167783,
            0.322050051167783, 0.322050051167783, 0.324866952576234,
            0.327691811333296, 0.327691811333296, 0.327691811333296,
            0.330548954190439, 0.333414283703333, 0.333414283703333,
            0.333414283703333, 0.333414283703333, 0.333414283703333,
            0.33633826031152, 0.339270811631168, 0.339270811631168,
            0.339270811631168, 0.339270811631168, 0.339270811631168,
            0.339270811631168, 0.339270811631168, 0.342273814634171,
            0.342273814634171, 0.345294962670425, 0.345294962670425,
            0.345294962670425, 0.345294962670425, 0.345294962670425,
            0.345294962670425, 0.345294962670425, 0.348390937902623,
            0.348390937902623, 0.351545512035115, 0.351545512035115,
            0.351545512035115, 0.351545512035115, 0.351545512035115,
            0.351545512035115, 0.351545512035115, 0.351545512035115,
            0.354824200559705, 0.354824200559705, 0.358124530592708,
            0.361435788870854, 0.361435788870854, 0.361435788870854,
            0.364780270476205, 0.364780270476205, 0.364780270476205,
            0.368158648854584, 0.368158648854584, 0.368158648854584,
            0.368158648854584, 0.368158648854584, 0.368158648854584,
            0.368158648854584, 0.368158648854584, 0.368158648854584,
            0.368158648854584, 0.368158648854584, 0.368158648854584,
            0.371717367715794, 0.371717367715794, 0.375314490017952,
            0.378937678423749, 0.382574042060113, 0.382574042060113,
            0.382574042060113, 0.382574042060113, 0.386305385343695,
            0.386305385343695, 0.390064783839935, 0.393838368745596,
            0.393838368745596, 0.393838368745596, 0.393838368745596,
            0.393838368745596, 0.393838368745596, 0.393838368745596,
            0.393838368745596, 0.393838368745596, 0.397775376619611,
            0.401727945789572, 0.401727945789572, 0.401727945789572,
            0.405727945789572, 0.4097440100466, 0.413776268111116,
            0.417824851107068, 0.421889891757474, 0.421889891757474,
            0.421889891757474, 0.426005118094923, 0.430137349499881,
            0.430137349499881, 0.434304016166548, 0.438488116584958,
            0.442689797257227, 0.442689797257227, 0.44692708539282,
            0.451182404541756, 0.45545590881526, 0.459747754308823,
            0.464058099136409, 0.468387103465413, 0.47273492955237,
            0.477101741779444, 0.481487706691725, 0.485892993035337,
            0.490317771796399, 0.490317771796399, 0.490317771796399,
            0.490317771796399, 0.490317771796399, 0.490317771796399,
            0.490317771796399, 0.490317771796399, 0.490317771796399,
            0.490317771796399, 0.490317771796399, 0.490317771796399,
            0.490317771796399, 0.490317771796399, 0.490317771796399,
            0.495125464104092, 0.495125464104092, 0.495125464104092,
            0.495125464104092, 0.500075959153596, 0.500075959153596,
            0.505075959153596, 0.505075959153596, 0.505075959153596,
            0.505075959153596, 0.505075959153596, 0.510204164281801,
            0.510204164281801, 0.510204164281801, 0.515412497615135,
            0.515412497615135, 0.515412497615135, 0.515412497615135,
            0.515412497615135, 0.515412497615135, 0.515412497615135,
            0.515412497615135, 0.515412497615135, 0.520937359493588,
            0.520937359493588, 0.520937359493588, 0.520937359493588,
            0.526587077007712, 0.526587077007712, 0.526587077007712,
            0.526587077007712, 0.532367423828521, 0.538249776769698,
            0.538249776769698, 0.538249776769698, 0.538249776769698,
            0.538249776769698, 0.538249776769698, 0.538249776769698,
            0.538249776769698, 0.538249776769698, 0.538249776769698,
            0.538249776769698, 0.538249776769698, 0.538249776769698,
            0.538249776769698, 0.538249776769698, 0.538249776769698,
            0.538249776769698, 0.544872293325989, 0.544872293325989,
            0.544872293325989, 0.544872293325989, 0.544872293325989,
            0.544872293325989, 0.544872293325989, 0.544872293325989,
            0.551914546847116, 0.551914546847116, 0.551914546847116,
            0.551914546847116, 0.551914546847116, 0.551914546847116,
            0.551914546847116, 0.551914546847116, 0.551914546847116,
            0.551914546847116, 0.551914546847116, 0.551914546847116,
            0.559788562595148, 0.559788562595148, 0.559788562595148,
            0.559788562595148, 0.559788562595148, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.576182005218098, 0.576182005218098,
            0.576182005218098, 0.586182005218098, 0.596283015319109,
            0.606487096951762, 0.616796375302277, 0.627213041968944,
            0.637739357758417, 0.648377655630758, 0.659130343802801,
            0.669999909020192, 0.680988920009203, 0.692100031120314,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.703335986176494, 0.703335986176494,
            0.703335986176494, 0.718261359310822, 0.718261359310822,
            0.718261359310822, 0.718261359310822, 0.718261359310822,
            0.718261359310822, 0.718261359310822, 0.718261359310822,
            0.718261359310822, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.735805218959945, 0.735805218959945,
            0.735805218959945, 0.819138552293278, 0.819138552293278,
            0.819138552293278, 0.819138552293278, 0.819138552293278,
            0.819138552293278, 0.819138552293278, 0.819138552293278,
            0.819138552293278, 1.15247188562661, 1.65247188562661,
            2.65247188562661
        ])

        dataset = make_whas500(with_mean=False, with_std=False)
        outcome = dataset.y
        unique_times, cum_hazard = nelson_aalen_estimator(
            outcome['fstat'], outcome['lenfol']
        )

        # The returned time points are cast to integers before comparing
        # them with the integer reference list.
        assert_array_equal(unique_times.astype(numpy.int_), expected_times)
        assert_array_almost_equal(cum_hazard, expected_hazard)
Example #5
0
    def test_simple(simple_data_na):
        """Nelson-Aalen estimate reproduces the ``simple_data_na`` fixture.

        The fixture is unpacked as ``(time, event, expected_times,
        expected_hazard)``.
        """
        # NOTE(review): indented like a method but takes no ``self``;
        # presumably decorated with ``@staticmethod`` in the enclosing class
        # (decorator not visible here) -- confirm.
        time, event, expected_times, expected_hazard = simple_data_na

        observed_times, observed_hazard = nelson_aalen_estimator(event, time)

        assert_array_equal(observed_times, expected_times)
        assert_array_almost_equal(observed_hazard, expected_hazard)