예제 #1
0
def test_logrank_veterans():
    """Check the log-rank comparison of the veterans data grouped by cell type.

    Calls ``compare_survival`` with ``return_stats=True`` and compares, in
    order, the per-group statistics table, the covariance matrix, the
    chi-squared statistic and the p-value against precomputed reference
    values.
    """
    features, outcome = load_veterans_lung_cancer()
    cell_type = features.loc[:, "Celltype"]
    chisq, pval, stats, covar = compare_survival(outcome, cell_type,
                                                 return_stats=True)

    observed = numpy.array([26, 26, 45, 31], dtype=numpy.int_)
    expected = numpy.array([
        15.6937646143605, 34.5494783863493, 30.1020793268148, 47.6546776724754
    ])
    expected_stats = pandas.DataFrame(
        {
            "counts": numpy.array([27, 27, 48, 35], dtype=numpy.intp),
            "observed": observed,
            "expected": expected,
            # per-group difference between observed and expected values
            "statistic": observed - expected,
        },
        index=pandas.Index(["adeno", "large", "smallcell", "squamous"],
                           name="group"),
    )
    assert_frame_equal(stats, expected_stats)

    expected_covar = numpy.array([
        [12.9661700605014, -4.07011754397142, -4.40872930298506, -4.48732321354496],
        [-4.07011754397142, 24.1990352938484, -7.81168661717217, -12.3172311327048],
        [-4.40872930298506, -7.81168661717217, 21.7542679406138, -9.53385202045655],
        [-4.48732321354496, -12.3172311327048, -9.53385202045655, 26.3384063667063],
    ])
    assert_almost_equal(covar, expected_covar)

    # Scalar results must agree with the references to 6 decimal places.
    for actual, reference in ((chisq, 25.4037003457854),
                              (pval, 1.27124593900609e-05)):
        assert round(abs(actual - reference), 6) == 0
예제 #2
0
def veterans():
    """Return whatever ``load_veterans_lung_cancer()`` yields, unchanged."""
    dataset = load_veterans_lung_cancer()
    return dataset
예제 #3
0
"""
https://scikit-survival.readthedocs.io/en/stable/user_guide/00-introduction.html
"""

from sksurv.datasets import load_veterans_lung_cancer

data_x, data_y = load_veterans_lung_cancer()
data_y

import pandas as pd

pd.DataFrame.from_records(data_y[[11, 5, 32, 13, 23]], index=range(1, 6))

%matplotlib inline
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator

time, survival_prob = kaplan_meier_estimator(data_y["Status"], data_y["Survival_in_days"])
plt.step(time, survival_prob, where="post")
plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")

data_x["Treatment"].value_counts()

for treatment_type in ("standard", "test"):
    mask_treat = data_x["Treatment"] == treatment_type
    time_treatment, survival_prob_treatment = kaplan_meier_estimator(
        data_y["Status"][mask_treat],
        data_y["Survival_in_days"][mask_treat])
예제 #4
0
 def test_load_veterans_lung_cancer():
     """Check the shapes and outcome dtype of the veterans lung cancer data."""
     n_samples = 137
     features, outcome = sdata.load_veterans_lung_cancer()
     assert features.shape == (n_samples, 6)
     assert outcome.shape == (n_samples,)
     # 128 is presumably the expected number of events checked by the
     # helper -- TODO confirm against assert_structured_array_dtype.
     assert_structured_array_dtype(outcome, 'Status', 'Survival_in_days', 128)
def plot_cumulative_dynamic_auc(risk_score, label, color=None):
    """Plot the time-dependent AUC of ``risk_score`` on the current axes.

    The AUC is computed from the module-level ``y_train``/``y_test`` survival
    data at the module-level evaluation ``times``. The per-time values are
    drawn with markers, and the summary AUC is drawn as a dashed horizontal
    line in the same ``color``. The legend is refreshed after each call.
    """
    auc_per_time, auc_summary = cumulative_dynamic_auc(
        y_train, y_test, risk_score, times
    )
    style = {"color": color}

    plt.plot(times, auc_per_time, marker="o", label=label, **style)
    plt.axhline(auc_summary, linestyle="--", **style)
    plt.xlabel("days from enrollment")
    plt.ylabel("time-dependent AUC")
    plt.legend()


# For each numeric test column: plot its time-dependent AUC and compute the
# IPCW concordance index with tau set to the last evaluation time.
ipcw_concordance = {}
for i, col in enumerate(num_columns):
    plot_cumulative_dynamic_auc(x_test[:, i], col, color="C{}".format(i))
    ret = concordance_index_ipcw(y_train, y_test, x_test[:, i], tau=times[-1])
    # Keep every column's result. Previously `ret` was overwritten on each
    # iteration, so only the last column's result survived the loop.
    # `ret` is still bound afterwards for any code that reads it.
    ipcw_concordance[col] = ret

from sksurv.datasets import load_veterans_lung_cancer

# Veterans lung cancer data: features `va_x`, survival outcome `va_y`.
va_x, va_y = load_veterans_lung_cancer()

# One-hot encoding followed by a Cox proportional hazards model.
# Pipeline.fit returns the pipeline itself, so the fitted model is bound directly.
cph = make_pipeline(OneHotEncoder(), CoxPHSurvivalAnalysis()).fit(va_x, va_y)

# Evaluation times: every 7 days from day 7 through day 182 (183 is excluded).
va_times = np.arange(7, 183, 7)

# Performance is estimated on the training data itself, so `va_y` is passed
# both as the reference and as the evaluated outcome. This in-sample
# estimate is likely optimistic.
va_auc, va_mean_auc = cumulative_dynamic_auc(
    va_y, va_y, cph.predict(va_x), va_times
)

plt.plot(va_times, va_auc, marker="o")
plt.axhline(va_mean_auc, linestyle="--")  # summary AUC across va_times
plt.xlabel("days from enrollment")
plt.ylabel("time-dependent AUC")
plt.grid(True)