Esempio n. 1
0
def test_tree_two_split(veterans):
    """Check a depth-2 survival tree fitted on the Karnofsky score alone.

    A ``SurvivalTree`` (``max_depth=2``, ``max_features=1``) is fitted on the
    single ``Karnofsky_score`` column of the ``veterans`` fixture. The test
    pins the fitted node thresholds and per-node sample counts, and the risk
    scores predicted for ten fixed query points. It also checks that every
    predicted cumulative hazard curve is non-decreasing and every predicted
    survival curve is non-increasing.
    """
    features, outcome = veterans
    x_train = features.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]

    model = SurvivalTree(max_depth=2, max_features=1)
    model.fit(x_train, outcome)
    fitted = model.tree_

    # Seven nodes in total. Leaves carry TREE_UNDEFINED as threshold.
    # Nodes 2-3 and 5-6 are the children of nodes 1 and 4, respectively:
    # their sample counts add up (8 + 30 == 38, 91 + 8 == 99).
    assert fitted.capacity == 7
    leaf = TREE_UNDEFINED
    assert_array_equal(
        fitted.threshold,
        numpy.array([45.0, 25.0, leaf, leaf, 87.5, leaf, leaf]))
    n_total = x_train.shape[0]
    assert_array_equal(fitted.n_node_samples,
                       numpy.array([n_total, 38, 8, 30, 99, 91, 8]))

    # Expected risk for each leaf, named after the range of query scores
    # that end up in it.
    risk_upto_25 = 214.027380952381
    risk_25_to_45 = 179.264571990757
    risk_45_to_87_5 = 96.7044629620645
    risk_above_87_5 = 19.6309523809524

    x_query = numpy.array([
        66.05, 87.91, 45.62, 40.18, 50.65, 71.24, 96.21, 33.33, 11.57, 94.28
    ]).reshape(-1, 1)
    expected_risk = numpy.array([
        risk_45_to_87_5,  # 66.05
        risk_above_87_5,  # 87.91
        risk_45_to_87_5,  # 45.62
        risk_25_to_45,  # 40.18
        risk_45_to_87_5,  # 50.65
        risk_45_to_87_5,  # 71.24
        risk_above_87_5,  # 96.21
        risk_25_to_45,  # 33.33
        risk_upto_25,  # 11.57
        risk_above_87_5,  # 94.28
    ])
    risk_pred = model.predict(x_query)
    assert_array_almost_equal(risk_pred, expected_risk)

    # Differences are taken along the last axis. This assumes each row of
    # the returned array is one query's curve over increasing time
    # (TODO confirm against the predict_* docs).
    cum_hazard = model.predict_cumulative_hazard_function(x_query,
                                                          return_array=True)
    assert (numpy.diff(cum_hazard, axis=-1) >= 0).all()

    survival = model.predict_survival_function(x_query, return_array=True)
    assert (numpy.diff(survival, axis=-1) <= 0).all()
Esempio n. 2
0
def test_tree_one_split(veterans):
    """Check a depth-1 survival tree against ``LogrankTreeBuilder``.

    Both a ``SurvivalTree`` and the reference ``LogrankTreeBuilder`` are built
    with ``max_depth=1`` on the ``Karnofsky_score`` column of the
    ``veterans`` fixture. Their node features, sample counts and thresholds
    must agree, and the tree's ``event_times_`` are pinned to a fixed list.

    One query point is placed on each side of the split. For each, the
    predicted cumulative hazard and survival curves must match the
    Nelson-Aalen and Kaplan-Meier estimates computed directly on the
    training samples of that side. The predicted risk scores must also
    match fixed values.
    """
    features, outcome = veterans
    x_train = features.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]

    model = SurvivalTree(max_depth=1)
    model.fit(x_train, outcome)
    fitted = model.tree_

    # Reference node table: the tree's node count must equal its row count,
    # and its columns must match the tree's per-node arrays.
    reference = LogrankTreeBuilder(max_depth=1).build(x_train, outcome)

    assert fitted.capacity == reference.shape[0]
    # The integer-valued columns share their names with the tree attributes.
    for column in ("feature", "n_node_samples"):
        assert_array_equal(getattr(fitted, column),
                           reference.loc[:, column].values)
    assert_array_almost_equal(fitted.threshold,
                              reference.loc[:, "threshold"].values)

    expected_time = numpy.array([
        1, 2, 3, 4, 7, 8, 10, 11, 12, 13, 15, 16, 18, 19, 20, 21, 22, 24, 25,
        27, 29, 30, 31, 33, 35, 36, 42, 43, 44, 45, 48, 49, 51, 52, 53, 54, 56,
        59, 61, 63, 72, 73, 80, 82, 84, 87, 90, 92, 95, 99, 100, 103, 105, 110,
        111, 112, 117, 118, 122, 126, 132, 133, 139, 140, 143, 144, 151, 153,
        156, 162, 164, 177, 186, 200, 201, 216, 228, 231, 242, 250, 260, 278,
        283, 287, 314, 340, 357, 378, 384, 389, 392, 411, 467, 553, 587, 991,
        999,
    ], dtype=float)
    assert_array_equal(model.event_times_, expected_time)

    # Training samples at or below the root threshold form the left side.
    split_at = reference.loc[0, "threshold"]
    goes_left = x_train[:, 0] <= split_at
    y_left = outcome[goes_left]
    y_right = outcome[~goes_left]

    def curves_per_side(estimator):
        """Return ``estimator``'s curve for the left, then right, samples."""
        curves = []
        for part in (y_left, y_right):
            _, curve = estimator(part["Status"], part["Survival_in_days"])
            curves.append(curve)
        return curves

    # One query point 10 units below and one 10 units above the split.
    x_query = numpy.array([[split_at - 10], [split_at + 10]])

    na_left, na_right = curves_per_side(nelson_aalen_estimator)
    chf_pred = model.predict_cumulative_hazard_function(x_query,
                                                        return_array=True)
    assert_curve_almost_equal(chf_pred[0], na_left)
    assert_curve_almost_equal(chf_pred[1], na_right)

    assert_array_almost_equal(model.predict(x_query),
                              numpy.array([196.55878, 86.14939]))

    km_left, km_right = curves_per_side(kaplan_meier_estimator)
    surv_pred = model.predict_survival_function(x_query, return_array=True)
    assert_curve_almost_equal(surv_pred[0], km_left)
    assert_curve_almost_equal(surv_pred[1], km_right)