def test_tree_two_split(veterans):
    """Check a depth-2 survival tree grown on the Karnofsky score alone.

    The fitted tree must have the expected split thresholds and node sizes,
    assign the expected risk score to a fixed set of query scores, and
    produce cumulative hazard curves that never decrease and survival
    curves that never increase.
    """
    data, outcome = veterans
    scores = data.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]

    model = SurvivalTree(max_depth=2, max_features=1)
    model.fit(scores, outcome)
    fitted = model.tree_

    # Nodes are listed depth-first: root (split at 45), its left child
    # (split at 25) followed by that child's two leaves, then the root's
    # right child (split at 87.5) followed by its two leaves.
    assert fitted.capacity == 7
    leaf = TREE_UNDEFINED
    expected_thresholds = numpy.array([45., 25., leaf, leaf, 87.5, leaf, leaf])
    assert_array_equal(fitted.threshold, expected_thresholds)

    # Number of training samples in each node, in the same depth-first order.
    expected_counts = numpy.array([scores.shape[0], 38, 8, 30, 99, 91, 8])
    assert_array_equal(fitted.n_node_samples, expected_counts)

    # Risk score of each leaf, named by the range of scores routed to it
    # (a value <= a node's threshold goes to that node's left child).
    risk_upto_25 = 214.027380952381
    risk_25_to_45 = 179.264571990757
    risk_45_to_87_5 = 96.7044629620645
    risk_above_87_5 = 19.6309523809524

    queries = numpy.array(
        [66.05, 87.91, 45.62, 40.18, 50.65, 71.24, 96.21, 33.33, 11.57, 94.28]
    ).reshape(-1, 1)
    expected_risk = numpy.array([
        risk_45_to_87_5,  # 66.05
        risk_above_87_5,  # 87.91
        risk_45_to_87_5,  # 45.62
        risk_25_to_45,    # 40.18
        risk_45_to_87_5,  # 50.65
        risk_45_to_87_5,  # 71.24
        risk_above_87_5,  # 96.21
        risk_25_to_45,    # 33.33
        risk_upto_25,     # 11.57
        risk_above_87_5,  # 94.28
    ])
    assert_array_almost_equal(model.predict(queries), expected_risk)

    # numpy.diff takes differences along the last axis, i.e. within each
    # query's curve: the hazard must not decrease, survival must not increase.
    hazard_curves = model.predict_cumulative_hazard_function(queries, return_array=True)
    assert (numpy.diff(hazard_curves) >= 0).all()

    survival_curves = model.predict_survival_function(queries, return_array=True)
    assert (numpy.diff(survival_curves) <= 0).all()
def test_tree_one_split(veterans):
    """Check a single-split survival tree against reference statistics and estimators.

    The fitted tree's structure must match what ``LogrankTreeBuilder``
    builds for the same data, and ``event_times_`` must equal the
    hard-coded time grid below. For each leaf, the predicted curves must
    reproduce two estimates computed directly on the samples routed to that
    leaf: the Nelson-Aalen cumulative hazard and the Kaplan-Meier survival.
    """
    data, outcome = veterans
    scores = data.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]

    model = SurvivalTree(max_depth=1)
    model.fit(scores, outcome)
    reference = LogrankTreeBuilder(max_depth=1).build(scores, outcome)
    fitted = model.tree_

    # Tree structure must agree node-by-node with the reference builder;
    # thresholds are floats, so they are only compared approximately.
    assert fitted.capacity == reference.shape[0]
    structure_checks = (
        ("feature", assert_array_equal),
        ("n_node_samples", assert_array_equal),
        ("threshold", assert_array_almost_equal),
    )
    for column, compare in structure_checks:
        compare(getattr(fitted, column), reference.loc[:, column].values)

    expected_times = numpy.array(
        [
            1, 2, 3, 4, 7, 8, 10, 11, 12, 13, 15, 16, 18, 19, 20, 21, 22, 24,
            25, 27, 29, 30, 31, 33, 35, 36, 42, 43, 44, 45, 48, 49, 51, 52,
            53, 54, 56, 59, 61, 63, 72, 73, 80, 82, 84, 87, 90, 92, 95, 99,
            100, 103, 105, 110, 111, 112, 117, 118, 122, 126, 132, 133, 139,
            140, 143, 144, 151, 153, 156, 162, 164, 177, 186, 200, 201, 216,
            228, 231, 242, 250, 260, 278, 283, 287, 314, 340, 357, 378, 384,
            389, 392, 411, 467, 553, 587, 991, 999,
        ],
        dtype=float,
    )
    assert_array_equal(model.event_times_, expected_times)

    # The test expects samples with a score <= the root threshold to end up
    # in the left leaf and all others in the right leaf.
    cut = reference.loc[0, "threshold"]
    goes_left = scores[:, 0] <= cut
    leaf_outcomes = (outcome[goes_left], outcome[~goes_left])

    # One query 10 below and one 10 above the threshold, so row 0 of each
    # prediction corresponds to the left leaf and row 1 to the right leaf.
    probes = numpy.array([[cut - 10], [cut + 10]])

    hazard = model.predict_cumulative_hazard_function(probes, return_array=True)
    for predicted, subset in zip(hazard, leaf_outcomes):
        _, reference_chf = nelson_aalen_estimator(subset["Status"], subset["Survival_in_days"])
        assert_curve_almost_equal(predicted, reference_chf)

    assert_array_almost_equal(model.predict(probes), numpy.array([196.55878, 86.14939]))

    survival = model.predict_survival_function(probes, return_array=True)
    for predicted, subset in zip(survival, leaf_outcomes):
        _, reference_surv = kaplan_meier_estimator(subset["Status"], subset["Survival_in_days"])
        assert_curve_almost_equal(predicted, reference_surv)