def test_truncated_enter_larger_exit_error():
    """Entry times strictly after the exit times must raise a ValueError."""
    rng = numpy.random.RandomState(2016)
    exit_times = rng.uniform(1, 100, size=25)
    # shift every entry past its exit so all samples are invalid
    enter_times = exit_times + 1
    events = rng.binomial(1, 0.6, size=25).astype(bool)

    with pytest.raises(ValueError,
                       match="exit time must be larger start time for all samples"):
        kaplan_meier_estimator(events, exit_times, enter_times)
def test_truncated_reverse_error():
    """The reverse (censoring) estimator must reject left-truncated data."""
    rng = numpy.random.RandomState(2016)
    exit_times = rng.uniform(1, 100, size=25)
    enter_times = exit_times + 1
    events = rng.binomial(1, 0.6, size=25).astype(bool)

    with pytest.raises(ValueError,
                       match="The censoring distribution cannot be estimated from left truncated data"):
        kaplan_meier_estimator(events, exit_times, enter_times, reverse=True)
def test_tree_one_split(veterans):
    # Fit a depth-1 survival tree on the Karnofsky score alone and check
    # that its structure and predictions agree with a reference log-rank
    # split computed by LogrankTreeBuilder.
    X, y = veterans
    X = X.loc[:, "Karnofsky_score"].values[:, numpy.newaxis]
    tree = SurvivalTree(max_depth=1)
    tree.fit(X, y)
    stats = LogrankTreeBuilder(max_depth=1).build(X, y)
    # tree layout (split feature, node sizes, threshold) must match
    assert tree.tree_.capacity == stats.shape[0]
    assert_array_equal(tree.tree_.feature, stats.loc[:, "feature"].values)
    assert_array_equal(tree.tree_.n_node_samples, stats.loc[:, "n_node_samples"].values)
    assert_array_almost_equal(tree.tree_.threshold, stats.loc[:, "threshold"].values)
    # unique event times of the training data the tree should expose
    expected_time = numpy.array([
        1, 2, 3, 4, 7, 8, 10, 11, 12, 13, 15, 16, 18, 19, 20, 21, 22, 24, 25,
        27, 29, 30, 31, 33, 35, 36, 42, 43, 44, 45, 48, 49, 51, 52, 53, 54,
        56, 59, 61, 63, 72, 73, 80, 82, 84, 87, 90, 92, 95, 99, 100, 103,
        105, 110, 111, 112, 117, 118, 122, 126, 132, 133, 139, 140, 143, 144,
        151, 153, 156, 162, 164, 177, 186, 200, 201, 216, 228, 231, 242, 250,
        260, 278, 283, 287, 314, 340, 357, 378, 384, 389, 392, 411, 467, 553,
        587, 991, 999
    ], dtype=float)
    assert_array_equal(tree.event_times_, expected_time)
    # Each leaf's predicted curves must equal the non-parametric estimators
    # computed over the samples that fall into that leaf.
    threshold = stats.loc[0, "threshold"]
    m = X[:, 0] <= threshold
    y_left = y[m]
    _, chf_left = nelson_aalen_estimator(y_left["Status"], y_left["Survival_in_days"])
    y_right = y[~m]
    _, chf_right = nelson_aalen_estimator(y_right["Status"], y_right["Survival_in_days"])
    # one query point on each side of the split
    X_pred = numpy.array([[threshold - 10], [threshold + 10]])
    chf_pred = tree.predict_cumulative_hazard_function(X_pred, return_array=True)
    assert_curve_almost_equal(chf_pred[0], chf_left)
    assert_curve_almost_equal(chf_pred[1], chf_right)
    mrt_pred = tree.predict(X_pred)
    assert_array_almost_equal(mrt_pred, numpy.array([196.55878, 86.14939]))
    _, surv_left = kaplan_meier_estimator(y_left["Status"], y_left["Survival_in_days"])
    _, surv_right = kaplan_meier_estimator(y_right["Status"], y_right["Survival_in_days"])
    surv_pred = tree.predict_survival_function(X_pred, return_array=True)
    assert_curve_almost_equal(surv_pred[0], surv_left)
    assert_curve_almost_equal(surv_pred[1], surv_right)
def test_truncated_male(self):
    # Left-truncated Kaplan-Meier on the male subset of the Channing House
    # data; rows with entry >= exit are dropped up front.
    data = pandas.read_csv(CHANNING_FILE).query("entry < exit")
    is_male = data.sex == "Male"
    time_enter_m = data.loc[is_male, "entry"].values
    time_exit_m = data.loc[is_male, "exit"].values
    event_m = data.loc[is_male, "cens"].values == 1
    x, y = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m)
    x_true = numpy.array([
        751, 759, 777, 781, 782, 806, 817, 820, 821, 823, 830, 835, 836, 837,
        843, 846, 847, 852, 853, 854, 856, 863, 865, 866, 869, 871, 872, 875,
        876, 878, 879, 883, 885, 886, 890, 891, 893, 894, 895, 898, 900, 906,
        907, 909, 911, 914, 915, 919, 921, 923, 925, 926, 927, 932, 936, 938,
        940, 943, 945, 946, 948, 951, 953, 955, 956, 957, 959, 960, 962, 964,
        966, 967, 969, 970, 971, 972, 973, 977, 978, 981, 982, 983, 984, 985,
        988, 989, 993, 996, 998, 1001, 1002, 1005, 1006, 1007, 1009, 1010,
        1012, 1013, 1015, 1016, 1018, 1020, 1021, 1022, 1023, 1025, 1027,
        1029, 1031, 1033, 1036, 1039, 1041, 1043, 1044, 1045, 1046, 1047,
        1051, 1053, 1055, 1058, 1059, 1060, 1063, 1064, 1070, 1073, 1080,
        1085, 1093, 1094, 1106, 1107, 1118, 1128, 1139, 1153
    ])
    assert_array_equal(x, x_true)
    # With left truncation the early risk set is tiny: the estimate stays at
    # 1, drops to 0.5, and is 0 for every later time point.
    assert_array_equal(y[:3], numpy.array([1., 1., .5]))
    self.assertTrue((y[3:] == 0).all())
def _get_npi(times):
    """Group GBSG2 patients by a Nottingham-Prognostic-Index-style score and
    predict, for each requested time point, the Kaplan-Meier survival
    probability of the patient's group.

    Parameters
    ----------
    times : sequence of float
        Time points at which survival probabilities are evaluated.

    Returns
    -------
    preds : ndarray, shape = (n_samples, len(times))
        Group-level survival probability per sample and time point.
    y : structured array
        Survival outcome as returned by ``load_gbsg2``.
    """
    X, y = load_gbsg2()
    grade = X.loc[:, "tgrade"].map({"I": 1, "II": 2, "III": 3}).astype(int)
    # NOTE(review): presumably an NPI-like score (size term + grade);
    # confirm the exact formula against the reference it was taken from.
    NPI = 0.2 * X.loc[:, "tsize"] / 10 + 1 + grade
    # Discretize into three risk groups. The masks are applied in this
    # order on purpose: values already set to 1.0 or 2.0 cannot match a
    # later condition again.
    NPI[NPI < 3.4] = 1.0
    NPI[(NPI >= 3.4) & (NPI <= 5.4)] = 2.0
    NPI[NPI > 5.4] = 3.0
    preds = numpy.empty((X.shape[0], len(times)), dtype=float)
    for j, ts in enumerate(times):
        survs = {}
        for i in NPI.unique():
            idx = numpy.flatnonzero(NPI == i)
            yi = y[idx]
            t, s = kaplan_meier_estimator(yi["cens"], yi["time"])
            if t[-1] < ts and s[-1] == 0.0:
                # curve already reached zero before ts: survival stays 0
                survs[i] = 0.0
            else:
                fn = StepFunction(t, s)
                survs[i] = fn(ts)
        # broadcast the group-level probability to each sample in the group
        preds[:, j] = NPI.map(survs).values
    return preds, y
def calc_kaplan_meier_variance(event, time):
    """Variance estimator of Kaplan-Meier survival function.

    Implements Greenwood's formula: the squared survival estimate times the
    cumulative sum of ``d_i / (n_i * (n_i - d_i))`` over the unique times,
    where ``d_i`` is the number of events and ``n_i`` the number at risk.

    Parameters
    ----------
    event : array-like, shape = (n_samples,)
        Contains binary event indicators.

    time : array-like, shape = (n_samples,)
        Contains event/censoring times.

    Returns
    -------
    uniq_times : array, shape = (n_times,)
        Unique times.

    variance : array, shape = (n_times,)
        Variance estimator of Kaplan-Meier survival function.

    References
    ----------
    .. [1] Greenwood, M. (1926). "The natural duration of cancer".
           Reports on Public Health and Medical Subjects, 33, 1-26.
    """
    event, time = check_y_survival(event, time)
    check_consistent_length(event, time)
    uniq_times, n_events, n_at_risk = _compute_counts(event, time)
    prob_survival = kaplan_meier_estimator(event, time)[1]
    # NOTE(review): when n_at_risk == n_events at the last event time the
    # denominator is zero and numpy yields inf with a RuntimeWarning —
    # confirm callers tolerate this.
    variance = (prob_survival**2) * np.cumsum(n_events / (n_at_risk * (n_at_risk - n_events)))
    return uniq_times, variance
def test_truncated_male_older_68(make_channing):
    # Left-truncated Kaplan-Meier for males, restricted to ages above 68
    # years; time appears to be in months, hence time_min=68 * 12.
    time_enter_m, time_exit_m, event_m = make_channing('Male')
    x, y = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m, time_min=68 * 12)
    x_true = numpy.array(
        [817, 820, 821, 823, 830, 835, 836, 837, 843, 846, 847, 852, 853,
         854, 856, 863, 865, 866, 869, 871, 872, 875, 876, 878, 879, 883,
         885, 886, 890, 891, 893, 894, 895, 898, 900, 906, 907, 909, 911,
         914, 915, 919, 921, 923, 925, 926, 927, 932, 936, 938, 940, 943,
         945, 946, 948, 951, 953, 955, 956, 957, 959, 960, 962, 964, 966,
         967, 969, 970, 971, 972, 973, 977, 978, 981, 982, 983, 984, 985,
         988, 989, 993, 996, 998, 1001, 1002, 1005, 1006, 1007, 1009, 1010,
         1012, 1013, 1015, 1016, 1018, 1020, 1021, 1022, 1023, 1025, 1027,
         1029, 1031, 1033, 1036, 1039, 1041, 1043, 1044, 1045, 1046, 1047,
         1051, 1053, 1055, 1058, 1059, 1060, 1063, 1064, 1070, 1073, 1080,
         1085, 1093, 1094, 1106, 1107, 1118, 1128, 1139, 1153])
    assert_array_equal(x, x_true)
    # no events among the first 18 retained time points
    assert (y[:18] == 1).all()
    y_true = numpy.array(
        [0.958333, 0.958333, 0.920000, 0.920000, 0.884615, 0.884615,
         0.884615, 0.884615, 0.884615, 0.884615, 0.884615, 0.884615,
         0.858597, 0.833344, 0.833344, 0.808092, 0.808092, 0.808092,
         0.784324, 0.761256, 0.738187, 0.738187, 0.738187, 0.738187,
         0.738187, 0.738187, 0.738187, 0.738187, 0.718236, 0.698285,
         0.698285, 0.698285, 0.698285, 0.698285, 0.678889, 0.678889,
         0.659492, 0.659492, 0.659492, 0.659492, 0.659492, 0.641173,
         0.641173, 0.641173, 0.641173, 0.641173, 0.624300, 0.624300,
         0.608692, 0.608692, 0.593085, 0.593085, 0.593085, 0.593085,
         0.593085, 0.593085, 0.593085, 0.577877, 0.577877, 0.563060,
         0.563060, 0.548623, 0.519748, 0.519748, 0.504898, 0.504898,
         0.504898, 0.504898, 0.504898, 0.504898, 0.488611, 0.488611,
         0.458073, 0.458073, 0.458073, 0.458073, 0.458073, 0.458073,
         0.458073, 0.441713, 0.441713, 0.424724, 0.424724, 0.407735,
         0.390746, 0.372139, 0.354418, 0.354418, 0.354418, 0.338308,
         0.321393, 0.321393, 0.321393, 0.321393, 0.321393, 0.303538,
         0.285682, 0.285682, 0.266637, 0.247591, 0.247591, 0.247591,
         0.247591, 0.247591, 0.225083, 0.202575, 0.202575, 0.151931,
         0.151931, 0.151931, 0.151931, 0.101287, 0.050644, 0.050644])
    assert_array_almost_equal(y[18:], y_true)
def test_simple(simple_data_km):
    """Kaplan-Meier estimate matches the precomputed reference curve."""
    obs_time, obs_event, expected_time, expected_prob = simple_data_km

    est_time, est_prob = kaplan_meier_estimator(obs_event, obs_time)

    assert_array_equal(est_time, expected_time)
    assert_array_almost_equal(est_prob, expected_prob)
def test_truncated_female_older_68():
    """Left-truncated Kaplan-Meier for female Channing House residents,
    restricted to ages above 68 years (time in months, ``time_min=68 * 12``).
    """
    data = pandas.read_csv(CHANNING_FILE).query("entry < exit")
    is_male = data.sex == "Male"
    # Invert the boolean mask with bitwise NOT (~). The original used the
    # unary minus (-is_male), which is not supported for boolean arrays and
    # raises a TypeError with current numpy/pandas.
    time_enter_f = data.loc[~is_male, "entry"].values
    time_exit_f = data.loc[~is_male, "exit"].values
    event_f = data.loc[~is_male, "cens"].values == 1
    x, y = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f, time_min=68 * 12)
    x_true = numpy.array(
        [818, 819, 820, 821, 822, 823, 824, 825, 827, 828, 829, 830, 831,
         833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845,
         846, 847, 848, 849, 850, 851, 852, 854, 855, 856, 857, 858, 859,
         860, 861, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873,
         874, 875, 876, 877, 878, 881, 882, 883, 885, 886, 887, 888, 889,
         890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902,
         904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916,
         917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929,
         930, 931, 932, 933, 934, 935, 936, 938, 939, 940, 941, 942, 943,
         944, 945, 946, 947, 948, 950, 951, 952, 953, 954, 955, 956, 957,
         958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970,
         971, 972, 973, 975, 976, 977, 978, 979, 981, 982, 983, 984, 985,
         986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998,
         999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009,
         1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020,
         1021, 1022, 1023, 1024, 1026, 1027, 1028, 1029, 1030, 1031, 1032,
         1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043,
         1044, 1047, 1049, 1050, 1051, 1053, 1054, 1055, 1056, 1057, 1059,
         1061, 1062, 1063, 1064, 1065, 1068, 1070, 1071, 1072, 1073, 1074,
         1080, 1083, 1084, 1085, 1086, 1088, 1089, 1091, 1093, 1096, 1097,
         1102, 1105, 1109, 1114, 1115, 1119, 1122, 1131, 1132, 1134, 1140,
         1142, 1147, 1152, 1172, 1186, 1192, 1200, 1207])
    assert_array_equal(x, x_true)
    y_true = numpy.array(
        [1, 1, 1, 1, 0.973684, 0.973684, 0.973684, 0.973684, 0.973684,
         0.973684, 0.973684, 0.952968, 0.952968, 0.952968, 0.952968,
         0.952968, 0.952968, 0.952968, 0.952968, 0.952968, 0.937345,
         0.937345, 0.937345, 0.937345, 0.937345, 0.923355, 0.923355,
         0.923355, 0.923355, 0.923355, 0.923355, 0.923355, 0.923355,
         0.923355, 0.923355, 0.923355, 0.923355, 0.923355, 0.923355,
         0.923355, 0.913426, 0.913426, 0.913426, 0.913426, 0.913426,
         0.913426, 0.904383, 0.904383, 0.904383, 0.904383, 0.904383,
         0.895930, 0.895930, 0.895930, 0.895930, 0.895930, 0.895930,
         0.895930, 0.895930, 0.888402, 0.881120, 0.881120, 0.881120,
         0.881120, 0.881120, 0.881120, 0.881120, 0.881120, 0.881120,
         0.881120, 0.874826, 0.874826, 0.868621, 0.868621, 0.868621,
         0.868621, 0.862752, 0.862752, 0.862752, 0.851325, 0.851325,
         0.851325, 0.839821, 0.839821, 0.839821, 0.834259, 0.834259,
         0.834259, 0.834259, 0.828697, 0.828697, 0.828697, 0.828697,
         0.823281, 0.823281, 0.823281, 0.823281, 0.818070, 0.818070,
         0.818070, 0.812989, 0.812989, 0.807844, 0.807844, 0.802763,
         0.797682, 0.792601, 0.792601, 0.787709, 0.787709, 0.782992,
         0.782992, 0.782992, 0.778159, 0.773355, 0.773355, 0.773355,
         0.764093, 0.764093, 0.764093, 0.764093, 0.759463, 0.759463,
         0.759463, 0.759463, 0.759463, 0.754745, 0.754745, 0.754745,
         0.754745, 0.754745, 0.750086, 0.750086, 0.750086, 0.750086,
         0.745339, 0.745339, 0.745339, 0.740561, 0.740561, 0.740561,
         0.731419, 0.726847, 0.726847, 0.726847, 0.726847, 0.722247,
         0.717647, 0.717647, 0.712987, 0.712987, 0.712987, 0.703727,
         0.699067, 0.699067, 0.699067, 0.694311, 0.694311, 0.694311,
         0.689489, 0.669929, 0.665075, 0.660255, 0.660255, 0.650546,
         0.640689, 0.625904, 0.625904, 0.620897, 0.615849, 0.610801,
         0.605711, 0.605711, 0.600621, 0.595574, 0.590526, 0.580256,
         0.580256, 0.580256, 0.580256, 0.575029, 0.569655, 0.558700,
         0.553222, 0.547690, 0.542101, 0.542101, 0.536395, 0.524982,
         0.513570, 0.507598, 0.501482, 0.501482, 0.489100, 0.482830,
         0.482830, 0.476718, 0.476718, 0.470606, 0.464414, 0.464414,
         0.464414, 0.457873, 0.457873, 0.457873, 0.457873, 0.457873,
         0.457873, 0.457873, 0.430537, 0.410356, 0.410356, 0.403516,
         0.396559, 0.389478, 0.389478, 0.389478, 0.389478, 0.389478,
         0.389478, 0.389478, 0.373250, 0.373250, 0.373250, 0.373250,
         0.373250, 0.364363, 0.355254, 0.355254, 0.336556, 0.327460,
         0.327460, 0.318104, 0.308748, 0.299100, 0.299100, 0.289130,
         0.279160, 0.259220, 0.248851, 0.248851, 0.238031, 0.238031,
         0.226696, 0.226696, 0.215901, 0.215901, 0.215901, 0.215901,
         0.215901, 0.201508, 0.201508, 0.186007, 0.170507, 0.155006,
         0.155006, 0.155006, 0.139506, 0.139506, 0.122067, 0.104629,
         0.104629, 0.078472, 0.026157, 0.026157])
    assert_array_almost_equal(y, y_true)
def test_truncated_female(make_channing):
    # Left-truncated Kaplan-Meier on the female subset of the Channing
    # House data, starting at the earliest observed exit time.
    time_enter_f, time_exit_f, event_f = make_channing('Female')
    x, y = kaplan_meier_estimator(event_f, time_exit_f, time_enter_f)
    x_true = numpy.array(
        [733, 746, 748, 760, 762, 768, 769, 772, 775, 777, 783, 792, 794,
         795, 796, 797, 798, 799, 802, 804, 805, 807, 808, 809, 810, 811,
         812, 813, 814, 815, 818, 819, 820, 821, 822, 823, 824, 825, 827,
         828, 829, 830, 831, 833, 834, 835, 836, 837, 838, 839, 840, 841,
         842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 854, 855,
         856, 857, 858, 859, 860, 861, 863, 864, 865, 866, 867, 868, 869,
         870, 871, 872, 873, 874, 875, 876, 877, 878, 881, 882, 883, 885,
         886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898,
         899, 900, 901, 902, 904, 905, 906, 907, 908, 909, 910, 911, 912,
         913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925,
         926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 938, 939,
         940, 941, 942, 943, 944, 945, 946, 947, 948, 950, 951, 952, 953,
         954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966,
         967, 968, 969, 970, 971, 972, 973, 975, 976, 977, 978, 979, 981,
         982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994,
         995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006,
         1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017,
         1018, 1019, 1020, 1021, 1022, 1023, 1024, 1026, 1027, 1028, 1029,
         1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040,
         1041, 1042, 1043, 1044, 1047, 1049, 1050, 1051, 1053, 1054, 1055,
         1056, 1057, 1059, 1061, 1062, 1063, 1064, 1065, 1068, 1070, 1071,
         1072, 1073, 1074, 1080, 1083, 1084, 1085, 1086, 1088, 1089, 1091,
         1093, 1096, 1097, 1102, 1105, 1109, 1114, 1115, 1119, 1122, 1131,
         1132, 1134, 1140, 1142, 1147, 1152, 1172, 1186, 1192, 1200, 1207])
    assert_array_equal(x, x_true)
    # estimate stays at 1 until the first event is reached
    assert (y[:19] == 1).all()
    y_true = numpy.array(
        [0.952381, 0.952381, 0.952381, 0.952381, 0.952381, 0.952381,
         0.952381, 0.952381, 0.952381, 0.952381, 0.952381, 0.952381,
         0.952381, 0.952381, 0.952381, 0.927318, 0.927318, 0.927318,
         0.927318, 0.927318, 0.927318, 0.927318, 0.907588, 0.907588,
         0.907588, 0.907588, 0.907588, 0.907588, 0.907588, 0.907588,
         0.907588, 0.892710, 0.892710, 0.892710, 0.892710, 0.892710,
         0.879386, 0.879386, 0.879386, 0.879386, 0.879386, 0.879386,
         0.879386, 0.879386, 0.879386, 0.879386, 0.879386, 0.879386,
         0.879386, 0.879386, 0.879386, 0.869930, 0.869930, 0.869930,
         0.869930, 0.869930, 0.869930, 0.861317, 0.861317, 0.861317,
         0.861317, 0.861317, 0.853267, 0.853267, 0.853267, 0.853267,
         0.853267, 0.853267, 0.853267, 0.853267, 0.846097, 0.839161,
         0.839161, 0.839161, 0.839161, 0.839161, 0.839161, 0.839161,
         0.839161, 0.839161, 0.839161, 0.833167, 0.833167, 0.827258,
         0.827258, 0.827258, 0.827258, 0.821669, 0.821669, 0.821669,
         0.810786, 0.810786, 0.810786, 0.799829, 0.799829, 0.799829,
         0.794532, 0.794532, 0.794532, 0.794532, 0.789236, 0.789236,
         0.789236, 0.789236, 0.784077, 0.784077, 0.784077, 0.784077,
         0.779115, 0.779115, 0.779115, 0.774275, 0.774275, 0.769375,
         0.769375, 0.764536, 0.759697, 0.754858, 0.754858, 0.750199,
         0.750199, 0.745707, 0.745707, 0.745707, 0.741103, 0.736529,
         0.736529, 0.736529, 0.727708, 0.727708, 0.727708, 0.727708,
         0.723298, 0.723298, 0.723298, 0.723298, 0.723298, 0.718805,
         0.718805, 0.718805, 0.718805, 0.718805, 0.714368, 0.714368,
         0.714368, 0.714368, 0.709847, 0.709847, 0.709847, 0.705296,
         0.705296, 0.705296, 0.696589, 0.692235, 0.692235, 0.692235,
         0.692235, 0.687854, 0.683473, 0.683473, 0.679035, 0.679035,
         0.679035, 0.670216, 0.665778, 0.665778, 0.665778, 0.661249,
         0.661249, 0.661249, 0.656657, 0.638028, 0.633405, 0.628815,
         0.628815, 0.619567, 0.610180, 0.596099, 0.596099, 0.591330,
         0.586523, 0.581715, 0.576867, 0.576867, 0.572020, 0.567213,
         0.562406, 0.552625, 0.552625, 0.552625, 0.552625, 0.547646,
         0.542528, 0.532095, 0.526878, 0.521610, 0.516287, 0.516287,
         0.510852, 0.499983, 0.489114, 0.483427, 0.477602, 0.477602,
         0.465810, 0.459838, 0.459838, 0.454017, 0.454017, 0.448196,
         0.442299, 0.442299, 0.442299, 0.436069, 0.436069, 0.436069,
         0.436069, 0.436069, 0.436069, 0.436069, 0.410035, 0.390815,
         0.390815, 0.384301, 0.377675, 0.370931, 0.370931, 0.370931,
         0.370931, 0.370931, 0.370931, 0.370931, 0.355476, 0.355476,
         0.355476, 0.355476, 0.355476, 0.347012, 0.338337, 0.338337,
         0.320530, 0.311867, 0.311867, 0.302956, 0.294046, 0.284857,
         0.284857, 0.275362, 0.265866, 0.246876, 0.237001, 0.237001,
         0.226696, 0.226696, 0.215901, 0.215901, 0.205620, 0.205620,
         0.205620, 0.205620, 0.205620, 0.191912, 0.191912, 0.177150,
         0.162387, 0.147625, 0.147625, 0.147625, 0.132862, 0.132862,
         0.116255, 0.099647, 0.099647, 0.074735, 0.024912, 0.024912])
    assert_array_almost_equal(y[19:], y_true)
def display_km(dic, field, df):
    """Plot one Kaplan-Meier curve per category of ``field``.

    Parameters
    ----------
    dic : dict
        Maps category codes (values of ``df[field]``) to display labels.
    field : str
        Column of ``df`` holding the grouping variable.
    df : pandas.DataFrame
        Must contain ``field``, a binary event column ``fstat`` and a
        ``time`` column.
    """
    # iterate the dict directly; .keys() is redundant
    for k in dic:
        mask = df[field] == k
        ti, surv_prob = kaplan_meier_estimator(
            df["fstat"][mask].values.astype("bool"),
            df["time"][mask])
        plt.step(ti, surv_prob, where="post",
                 label="%s = %s (n = %d)" % (field, dic[k], mask.sum()))
    # raw string: '\h' is not a valid escape sequence and triggers a
    # SyntaxWarning on Python >= 3.12; the rendered label is unchanged.
    plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
    plt.xlabel("time $t$")
    plt.legend(loc="best")
def single_kmc(data, status, interval):
    """Plot a single Kaplan-Meier curve for one data set.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data set.
    status : str
        Event indicator column (truthy = event observed).
    interval : str
        Time-to-event column, in days.
    """
    # comprehension replaces the manual append loop
    sta = [bool(i) for i in data[status]]
    interv = data[interval]
    time, survival_prob = kaplan_meier_estimator(sta, interv)
    plt.step(time, survival_prob)
    plt.ylabel("est. probability of survival")
    plt.xlabel("time $(days)$")
def test_all_uncensored():
    """Without censoring, the KM curve equals the empirical survival function."""
    observed = [1, 2, 2, 3, 7, 6, 5, 5, 3, 9, 11, 23, 17, 13, 6, 13]
    indicator = numpy.ones(len(observed), dtype=bool)

    expected_time = numpy.array([1, 2, 3, 5, 6, 7, 9, 11, 13, 17, 23])
    expected_prob = numpy.array(
        [0.9375, 0.8125, 0.6875, 0.5625, 0.4375, 0.375,
         0.3125, 0.25, 0.125, 0.0625, 0])

    actual_time, actual_prob = kaplan_meier_estimator(indicator, observed)

    assert_array_equal(actual_time, expected_time)
    assert_array_almost_equal(actual_prob, expected_prob)
def compare_kmc(data, factor, status, interval):
    """Plot Kaplan-Meier curves for every level of ``factor`` and, when the
    factor is binary, print the log-rank test comparing the two groups.

    Parameters
    ----------
    data : pandas.DataFrame
        Input data set.
    factor : str
        Grouping column.
    status : str
        Event indicator column (truthy = event observed).
    interval : str
        Time-to-event column, in days.
    """
    levels = sorted(data[factor].drop_duplicates().tolist())
    for i in levels:
        group_i = data[data[factor] == i]
        # append a censored observation at time 0 so the curve starts at 1
        interv_i = group_i[interval].tolist()
        interv_i.append(0)
        sta_i = [bool(j) for j in group_i[status]]
        sta_i.append(False)
        time, survival_prob = kaplan_meier_estimator(sta_i, interv_i)
        plt.step(time, survival_prob, where='post',
                 label=str(factor) + '=%s' % i)
    plt.ylabel("est. probability of survival")
    plt.xlabel("time $(days)$")
    plt.legend(loc='best')
    if data[factor].nunique() != 2:
        print('The factor', factor, 'is non-binary, so the Logrank statistics is not calculated')
    else:
        time = []
        censor = []
        for i in levels:
            # BUG FIX: the original referenced an undefined global ``df``
            # here, raising a NameError; use the ``data`` argument instead.
            group_i = data[data[factor] == i]
            time.append(group_i[interval].tolist())
            censor.append(group_i[status].tolist())
        results = logrank_test(time[0], time[1], censor[0], censor[1])
        results.print_summary()
def product_limit_estimator(self):
    """Compute the product limit (Kaplan-Meier) estimator over the score.

    :return: product limit estimator score index, cumulative probability of
        positive label
    """
    _, labels, scores = self.truncate()
    # the estimator expects boolean event indicators
    labels = labels.astype(bool)
    return kaplan_meier_estimator(labels, scores)
def test_first_censored():
    """Censoring the earliest observation rescales the remaining curve."""
    observed = [1, 2, 2, 3, 7, 6, 5, 5, 3, 9, 11, 13, 17, 13, 6, 23]
    indicator = numpy.ones(len(observed), dtype=bool)
    indicator[0] = False

    expected_time = numpy.array([1, 2, 3, 5, 6, 7, 9, 11, 13, 17, 23])
    expected_prob = numpy.array(
        [1, 0.866666666666667, 0.733333333333333, 0.6, 0.466666666666667,
         0.4, 0.333333333333333, 0.266666666666667, 0.133333333333333,
         0.0666666666666667, 0])

    actual_time, actual_prob = kaplan_meier_estimator(indicator, observed)

    assert_array_equal(actual_time, expected_time)
    assert_array_almost_equal(actual_prob, expected_prob)
def test_simple(simple_data_km):
    """SurvivalFunctionEstimator agrees with kaplan_meier_estimator."""
    obs_time, obs_event, expected_time, expected_prob = simple_data_km

    est_time, est_prob = kaplan_meier_estimator(obs_event, obs_time)
    assert_array_equal(est_time, expected_time)
    assert_array_almost_equal(est_prob, expected_prob)

    # The estimator class prepends time zero with probability one,
    # therefore compare from index 1 onwards.
    surv = SurvivalFunctionEstimator().fit(Surv.from_arrays(obs_event, obs_time))
    assert_array_equal(surv.unique_time_[1:], expected_time)
    assert_array_almost_equal(surv.prob_[1:], expected_prob)

    predicted = surv.predict_proba(expected_time)
    assert_array_almost_equal(predicted, expected_prob)
def test_right_truncated_adults(make_aids):
    """Right-truncated AIDS data (adults): negating the times turns right
    truncation into left truncation; the estimate is flipped back before
    comparing."""
    event, time_enter, time_exit = make_aids('adults')

    est_x, est_y = kaplan_meier_estimator(event, -time_exit.values, -time_enter.values)

    true_x = numpy.array(
        [0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75,
         3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50,
         5.75, 6.00, 6.25, 6.50, 6.75, 7, 7.25, 7.75, 8])
    true_y = numpy.array(
        [0, 0.003, 0.004, 0.01, 0.02, 0.03, 0.05, 0.06, 0.07, 0.09, 0.11,
         0.13, 0.16, 0.18, 0.20, 0.21, 0.25, 0.29, 0.31, 0.34, 0.40, 0.49,
         0.54, 0.58, 0.61, 0.64, 0.73, 0.8, 0.8, 1, 1])

    assert_array_almost_equal(-est_x[::-1], true_x, 2)
    assert_array_almost_equal(est_y[::-1], true_y, 2)
def test_right_truncated_children(make_aids):
    """Right-truncated AIDS data (children): negating the times turns right
    truncation into left truncation; the estimate is flipped back before
    comparing."""
    event, time_enter, time_exit = make_aids('children')

    est_x, est_y = kaplan_meier_estimator(event, -time_exit.values, -time_enter.values)

    true_x = numpy.array(
        [0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75,
         3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 5, 5.25, 5.50, 5.75,
         6.50, 7])
    true_y = numpy.array(
        [0, 0.02, 0.09, 0.16, 0.22, 0.25, 0.30, 0.35, 0.35, 0.39, 0.43,
         0.46, 0.52, 0.56, 0.61, 0.61, 0.61, 0.67, 0.67, 0.67, 0.67, 1.00,
         1.00, 1.00])

    assert_array_almost_equal(-est_x[::-1], true_x, 2)
    assert_array_almost_equal(est_y[::-1], true_y, 2)
def test_right_truncated_adults():
    """Right-truncated AIDS adults data, read directly from disk; every
    diagnosis is an observed event."""
    data = pandas.read_csv(AIDS_ADULTS_FILE, comment="#")

    event = numpy.ones(data.shape[0], dtype=bool)
    # negate so right truncation becomes left truncation
    time_enter = 8 - data["INF"]
    time_exit = data["DIAG"]
    est_x, est_y = kaplan_meier_estimator(event, -time_exit.values, -time_enter.values)

    true_x = numpy.array(
        [0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75,
         3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50,
         5.75, 6.00, 6.25, 6.50, 6.75, 7, 7.25, 7.75, 8])
    true_y = numpy.array(
        [0, 0.003, 0.004, 0.01, 0.02, 0.03, 0.05, 0.06, 0.07, 0.09, 0.11,
         0.13, 0.16, 0.18, 0.20, 0.21, 0.25, 0.29, 0.31, 0.34, 0.40, 0.49,
         0.54, 0.58, 0.61, 0.64, 0.73, 0.8, 0.8, 1, 1])

    assert_array_almost_equal(-est_x[::-1], true_x, 2)
    assert_array_almost_equal(est_y[::-1], true_y, 2)
def test_truncated_male(make_channing):
    """Left-truncated Kaplan-Meier on the male Channing House subset.

    Early on the risk set is tiny, so the curve collapses quickly: it stays
    at 1, drops to 0.5, then to 0 for every later time point.
    """
    time_enter_m, time_exit_m, event_m = make_channing('Male')

    est_x, est_y = kaplan_meier_estimator(event_m, time_exit_m, time_enter_m)

    expected_x = numpy.array(
        [751, 759, 777, 781, 782, 806, 817, 820, 821, 823, 830, 835, 836,
         837, 843, 846, 847, 852, 853, 854, 856, 863, 865, 866, 869, 871,
         872, 875, 876, 878, 879, 883, 885, 886, 890, 891, 893, 894, 895,
         898, 900, 906, 907, 909, 911, 914, 915, 919, 921, 923, 925, 926,
         927, 932, 936, 938, 940, 943, 945, 946, 948, 951, 953, 955, 956,
         957, 959, 960, 962, 964, 966, 967, 969, 970, 971, 972, 973, 977,
         978, 981, 982, 983, 984, 985, 988, 989, 993, 996, 998, 1001, 1002,
         1005, 1006, 1007, 1009, 1010, 1012, 1013, 1015, 1016, 1018, 1020,
         1021, 1022, 1023, 1025, 1027, 1029, 1031, 1033, 1036, 1039, 1041,
         1043, 1044, 1045, 1046, 1047, 1051, 1053, 1055, 1058, 1059, 1060,
         1063, 1064, 1070, 1073, 1080, 1085, 1093, 1094, 1106, 1107, 1118,
         1128, 1139, 1153])
    assert_array_equal(est_x, expected_x)

    assert_array_equal(est_y[:3], numpy.array([1., 1., .5]))
    assert (est_y[3:] == 0).all()
def test_right_truncated_children():
    """Right-truncated AIDS children data, read directly from disk; every
    diagnosis is an observed event."""
    data = pandas.read_csv(AIDS_CHILDREN_FILE, comment="#")

    event = numpy.ones(data.shape[0], dtype=bool)
    # negate so right truncation becomes left truncation
    time_enter = 8 - data["INF"]
    time_exit = data["DIAG"]
    est_x, est_y = kaplan_meier_estimator(event, -time_exit.values, -time_enter.values)

    true_x = numpy.array(
        [0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75,
         3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 5, 5.25, 5.50, 5.75,
         6.50, 7])
    true_y = numpy.array(
        [0, 0.02, 0.09, 0.16, 0.22, 0.25, 0.30, 0.35, 0.35, 0.39, 0.43,
         0.46, 0.52, 0.56, 0.61, 0.61, 0.61, 0.67, 0.67, 0.67, 0.67, 1.00,
         1.00, 1.00])

    assert_array_almost_equal(-est_x[::-1], true_x, 2)
    assert_array_almost_equal(est_y[::-1], true_y, 2)
def kaplanMeier(self, outfile, extension="png"):
    """Save one Kaplan-Meier plot per dataset held by the model.

    Each figure is written to ``<outfile>_<datasetName>.<extension>``.
    """
    self.upgradeInfo("Generating Kaplan-Meier plot")
    from sksurv.nonparametric import kaplan_meier_estimator
    for datasetName, dataset in self.model.dataset.items():
        tags = dataset["tags"]
        times, survival_prob = kaplan_meier_estimator(
            tags["Status"], tags["Time_in_days"])
        _, ax = plt.subplots()
        ax.step(times, survival_prob, where="post")
        ax.set_ylabel("Est. probability of survival")
        ax.set_xlabel("Days")
        ax.set_title(f"Kaplan-Meier curve {datasetName} dataset")
        plt.savefig(f"{outfile}_{datasetName}.{extension}",
                    dpi=100, bbox_inches="tight")
        plt.close()
"""Kaplan-Meier survival curves for the stroke vital-sign cohort."""
from sksurv.datasets import load_veterans_lung_cancer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator
import os

data = pd.read_csv(os.path.join('../..', '..', 'data', 'tidy_Stroke_Vital_Sign.csv'))
data_x = data.drop(['UID', 'Hospital_ID', 'SurvivalWeeks', 'admission_date',
                    'discharge_date', 'death_date'], axis=1)
# data_x = data[['Smoking']]
# .copy() so the astype assignment below modifies an independent frame,
# not a view of ``data`` (avoids pandas SettingWithCopyWarning / silent
# non-propagation).
data_y = data[['Mortality', 'SurvivalWeeks']].copy()
data_y['Mortality'] = data_y['Mortality'].astype(bool)

# KM-All survival
time, survival_prob = kaplan_meier_estimator(data_y['Mortality'], data_y['SurvivalWeeks'])
plt.step(time, survival_prob, where="post")
# raw strings: '\h' is not a valid escape sequence (SyntaxWarning on 3.12+)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.show()

# KM-AF survival: compare patients with and without atrial fibrillation
data_af = data_y[data_x.AF == 1]
data_non_af = data_y[data_x.AF == 0]
af_time, af_survival_prob = kaplan_meier_estimator(data_af['Mortality'], data_af['SurvivalWeeks'])
plt.step(af_time, af_survival_prob, where="post", label="AF")
non_af_time, non_af_survival_prob = kaplan_meier_estimator(data_non_af['Mortality'], data_non_af['SurvivalWeeks'])
plt.step(non_af_time, non_af_survival_prob, where="post", label="Non AF")
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")
def test_whas500(make_whas500):
    # Kaplan-Meier estimate on the full WHAS500 cohort (no truncation),
    # compared against a precomputed reference curve.
    whas500 = make_whas500(with_mean=False, with_std=False)
    time = whas500.y['lenfol']
    event = whas500.y['fstat']
    x, y = kaplan_meier_estimator(event, time)
    # unique follow-up times (days) expected in the estimate
    true_x = numpy.array([
        1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 16, 17, 18, 19, 20, 22, 26, 31, 32,
        33, 34, 37, 42, 46, 49, 52, 53, 55, 57, 60, 61, 62, 64, 69, 76, 81,
        83, 88, 91, 93, 95, 97, 100, 101, 108, 109, 113, 116, 117, 118, 129,
        132, 134, 135, 137, 140, 143, 145, 146, 151, 166, 169, 187, 192, 197,
        200, 226, 233, 235, 259, 269, 274, 287, 289, 295, 297, 312, 313, 321,
        328, 343, 345, 354, 358, 359, 363, 368, 371, 373, 376, 382, 385, 386,
        390, 392, 397, 398, 399, 400, 403, 405, 406, 407, 408, 411, 412, 416,
        418, 419, 421, 422, 424, 426, 427, 433, 437, 440, 442, 445, 446, 449,
        450, 451, 452, 457, 458, 459, 465, 466, 467, 473, 475, 478, 479, 480,
        486, 497, 506, 507, 510, 511, 516, 519, 521, 522, 523, 524, 529, 530,
        532, 535, 537, 542, 544, 550, 551, 552, 554, 559, 562, 568, 570, 573,
        578, 587, 589, 606, 609, 612, 614, 626, 631, 632, 644, 646, 649, 654,
        659, 662, 670, 673, 675, 704, 714, 718, 725, 849, 865, 903, 905, 920,
        936, 953, 1048, 1054, 1065, 1096, 1098, 1102, 1103, 1105, 1106, 1107,
        1108, 1109, 1114, 1117, 1121, 1123, 1125, 1126, 1136, 1140, 1150,
        1151, 1152, 1157, 1159, 1160, 1161, 1162, 1163, 1165, 1169, 1170,
        1174, 1178, 1182, 1187, 1189, 1190, 1191, 1196, 1199, 1200, 1203,
        1207, 1211, 1217, 1223, 1224, 1231, 1232, 1233, 1234, 1235, 1244,
        1245, 1248, 1251, 1253, 1256, 1257, 1262, 1265, 1266, 1272, 1273,
        1274, 1277, 1279, 1280, 1290, 1295, 1298, 1302, 1308, 1314, 1317,
        1319, 1320, 1325, 1329, 1332, 1333, 1336, 1338, 1346, 1347, 1353,
        1359, 1363, 1365, 1366, 1374, 1377, 1378, 1381, 1384, 1385, 1388,
        1390, 1400, 1408, 1409, 1420, 1430, 1433, 1438, 1444, 1449, 1451,
        1454, 1456, 1458, 1496, 1506, 1527, 1536, 1548, 1553, 1576, 1577,
        1579, 1624, 1627, 1671, 1831, 1836, 1847, 1854, 1858, 1863, 1880,
        1883, 1885, 1887, 1889, 1893, 1899, 1904, 1914, 1919, 1920, 1923,
        1926, 1931, 1933, 1934, 1936, 1939, 1940, 1941, 1942, 1954, 1955,
        1964, 1969, 1976, 1977, 1979, 1993, 1994, 2006, 2009, 2025, 2032,
        2048, 2057, 2061, 2064, 2065, 2066, 2083, 2084, 2086, 2100, 2108,
        2113, 2114, 2118, 2122, 2123, 2125, 2126, 2131, 2132, 2139, 2145,
        2146, 2151, 2152, 2156, 2160, 2166, 2168, 2172, 2173, 2175, 2178,
        2190, 2192, 2350, 2353, 2358
    ])
    assert_array_equal(x.astype(numpy.int_), true_x)
    # reference survival probabilities at the times above
    true_y = numpy.array([
        0.984, 0.968, 0.962, 0.958, 0.954, 0.944, 0.932, 0.926, 0.918,
        0.914, 0.912, 0.908, 0.902, 0.896, 0.892, 0.888, 0.886, 0.884,
        0.88, 0.874, 0.872, 0.87, 0.868, 0.866, 0.864, 0.862, 0.86, 0.858,
        0.854, 0.852, 0.85, 0.848, 0.844, 0.84, 0.838, 0.836, 0.834, 0.832,
        0.83, 0.828, 0.826, 0.824, 0.822, 0.82, 0.818, 0.816, 0.814, 0.812,
        0.81, 0.808, 0.806, 0.804, 0.802, 0.8, 0.798, 0.794, 0.792, 0.79,
        0.788, 0.786, 0.784, 0.78, 0.776, 0.774, 0.772, 0.77, 0.768, 0.766,
        0.764, 0.76, 0.758, 0.756, 0.754, 0.752, 0.75, 0.746, 0.744, 0.742,
        0.74, 0.738, 0.736, 0.734, 0.732, 0.73, 0.726, 0.724, 0.724, 0.724,
        0.724, 0.724, 0.721960563380282, 0.719921126760564,
        0.719921126760564, 0.719921126760564, 0.717864209255533,
        0.715807291750503, 0.715807291750503, 0.715807291750503,
        0.715807291750503, 0.715807291750503, 0.713714287973455,
        0.711621284196407, 0.711621284196407, 0.711621284196407,
        0.711621284196407, 0.711621284196407, 0.711621284196407,
        0.711621284196407, 0.709484283342964, 0.709484283342964,
        0.70734082629359, 0.70734082629359, 0.70734082629359,
        0.70734082629359, 0.70734082629359, 0.70734082629359,
        0.70734082629359, 0.705150916614662, 0.705150916614662,
        0.702926465773606, 0.702926465773606, 0.702926465773606,
        0.702926465773606, 0.702926465773606, 0.702926465773606,
        0.702926465773606, 0.702926465773606, 0.700621788836644,
        0.700621788836644, 0.69830950570517, 0.695997222573696,
        0.695997222573696, 0.695997222573696, 0.693669472665422,
        0.693669472665422, 0.693669472665422, 0.691325994717228,
        0.691325994717228, 0.691325994717228, 0.691325994717228,
        0.691325994717228, 0.691325994717228, 0.691325994717228,
        0.691325994717228, 0.691325994717228, 0.691325994717228,
        0.691325994717228, 0.691325994717228, 0.688865759860583,
        0.688865759860583, 0.686387825472596, 0.683900913061463,
        0.68141400065033, 0.68141400065033, 0.68141400065033,
        0.68141400065033, 0.678871411095665, 0.678871411095665,
        0.676319262933651, 0.673767114771638, 0.673767114771638,
        0.673767114771638, 0.673767114771638, 0.673767114771638,
        0.673767114771638, 0.673767114771638, 0.673767114771638,
        0.673767114771638, 0.671114488335529, 0.66846186189942,
        0.66846186189942, 0.66846186189942, 0.665788014451822,
        0.663114167004225, 0.660440319556627, 0.657766472109029,
        0.655092624661431, 0.655092624661431, 0.655092624661431,
        0.652396770238956, 0.649700915816481, 0.649700915816481,
        0.646993828667246, 0.644286741518011, 0.641579654368776,
        0.641579654368776, 0.638861096511281, 0.636142538653786,
        0.633423980796291, 0.630705422938796, 0.627986865081301,
        0.625268307223807, 0.622549749366312, 0.619831191508817,
        0.617112633651322, 0.614394075793827, 0.611675517936333,
        0.611675517936333, 0.611675517936333, 0.611675517936333,
        0.611675517936333, 0.611675517936333, 0.611675517936333,
        0.611675517936333, 0.611675517936333, 0.611675517936333,
        0.611675517936333, 0.611675517936333, 0.611675517936333,
        0.611675517936333, 0.611675517936333, 0.608734770253946,
        0.608734770253946, 0.608734770253946, 0.608734770253946,
        0.605721231787343, 0.605721231787343, 0.602692625628406,
        0.602692625628406, 0.602692625628406, 0.602692625628406,
        0.602692625628406, 0.599601894214927, 0.599601894214927,
        0.599601894214927, 0.596478967682557, 0.596478967682557,
        0.596478967682557, 0.596478967682557, 0.596478967682557,
        0.596478967682557, 0.596478967682557, 0.596478967682557,
        0.596478967682557, 0.593183503772709, 0.593183503772709,
        0.593183503772709, 0.593183503772709, 0.589832184542355,
        0.589832184542355, 0.589832184542355, 0.589832184542355,
        0.586422749949625, 0.582973204361686, 0.582973204361686,
        0.582973204361686, 0.582973204361686, 0.582973204361686,
        0.582973204361686, 0.582973204361686, 0.582973204361686,
        0.582973204361686, 0.582973204361686, 0.582973204361686,
        0.582973204361686, 0.582973204361686, 0.582973204361686,
        0.582973204361686, 0.582973204361686, 0.582973204361686,
        0.579112454663926, 0.579112454663926, 0.579112454663926,
        0.579112454663926, 0.579112454663926, 0.579112454663926,
        0.579112454663926, 0.579112454663926, 0.575034197940941,
        0.575034197940941, 0.575034197940941, 0.575034197940941,
        0.575034197940941, 0.575034197940941, 0.575034197940941,
        0.575034197940941, 0.575034197940941, 0.575034197940941,
        0.575034197940941, 0.575034197940941, 0.570506369610697,
        0.570506369610697, 0.570506369610697, 0.570506369610697,
        0.570506369610697, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.561153806174456, 0.561153806174456, 0.561153806174456,
        0.555542268112711, 0.549930730050967, 0.544319191989222,
        0.538707653927478, 0.533096115865733, 0.527484577803989,
        0.521873039742244, 0.5162615016805, 0.510649963618755,
        0.505038425557011, 0.499426887495266, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.493815349433521, 0.493815349433521, 0.493815349433521,
        0.486444971083767, 0.486444971083767, 0.486444971083767,
        0.486444971083767, 0.486444971083767, 0.486444971083767,
        0.486444971083767, 0.486444971083767, 0.486444971083767,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.477910848784052, 0.477910848784052, 0.477910848784052,
        0.438084944718714, 0.438084944718714, 0.438084944718714,
        0.438084944718714, 0.438084944718714, 0.438084944718714,
        0.438084944718714, 0.438084944718714, 0.438084944718714,
        0.292056629812476, 0.146028314906238, 0
    ])
    assert_array_almost_equal(y, true_y)
def mmm(a, df=None):
    """Print mean and median for a Series, or recurse over a list of column names.

    If ``a`` is a single pandas-like Series, prints its name, mean, and median.
    If ``a`` is a list, each element is treated as a column name looked up in
    ``df`` and the function recurses on that column.
    """
    # NOTE(review): `type(a) != list` would be more idiomatic as
    # `not isinstance(a, list)`; kept byte-identical here.
    if type(a) != list:
        print("For ", a.name, "mean: ", str(a.mean()), " median: ", str(a.median()))
    else:
        for a1 in a:
            mmm(df[a1])


# In[ ]:

## Using the survival prediction models

# Kaplan-Meier survival curve for drives with more than 10 reported
# smart_187 (reported uncorrectable) errors; smart_9_raw is power-on hours,
# divided by 24 to plot in days.
time, survival_prob = kaplan_meier_estimator(
    [bool(x) for x in data[data['smart_187_raw'] > 10]['failure'].values],
    data[data['smart_187_raw'] > 10]['smart_9_raw'].values / 24.)
plt.step(time, survival_prob, where="post", label='Errors >10')

# SMART attributes commonly cited as failure predictors.
crit_features = [5, 10, 184, 187, 188, 196, 197, 198, 201]
crit_names = ['smart_' + str(x) + "_raw" for x in crit_features]
mmm(crit_names, df=data)
# NOTE(review): `smart_features_raw` is referenced here but only assigned two
# statements below — in this flattened ordering this line would raise
# NameError. Presumably the assignment preceded this call in the original
# notebook; confirm against the source notebook's cell order.
mmm(smart_features_raw, df=data)

## Narrowing down to the columns that matter!
smart_features_raw = [x for x in data.columns if "smart" in x and "raw" in x]

## From wiki, most important features
# NOTE(review): the following loop is truncated in this view.
for i in [
risk_score = predictor.predict(x_test) high_risk_masks.append( risk_score > np.median(risk_score_train)) y_tests.append(y_test) except Exception as e: logger.warning("Error {}".format(str(e))) c_indexes.append(np.NaN) # ----------------------- Kaplan-Meier -------------------------------- high_risk_mask = np.concatenate(high_risk_masks) y_tests = np.concatenate(y_tests) y_high_risk, y_low_risk = y_tests[high_risk_mask], y_tests[ ~high_risk_mask] km_high_time, km_high_prob = kaplan_meier_estimator( y_high_risk['event'], y_high_risk['time']) km_low_time, km_low_prob = kaplan_meier_estimator( y_low_risk['event'], y_low_risk['time']) km_ests = [[km_high_time.tolist(), km_high_prob.tolist()], [km_low_time.tolist(), km_low_prob.tolist()]] p_vals = logrank_test(y_tests[['event', 'time']], high_risk_mask) dict_append(results, c_indexes, anno, 'c-index', mode) dict_append(results, p_vals, anno, 'kaplan-meier_p-value', mode) dict_append(results, km_ests, anno, 'kaplan-meier_estimate', mode)
import pandas as pd
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator

# Load the evaluation table; SID is the sample identifier column.
path = 'csv/PC200_agesmr_eval.csv'
data = pd.read_csv(path, index_col='SID', delimiter=',')

# Partition rows into a 3x3 grid of cells keyed "11".."33" by the
# maintenance category ID (メンテナンスCID, "maintenance CID") and the
# usage-pattern category ID (使われ方CID, "usage CID"), then plot and save a
# Kaplan-Meier curve per cell. FSTAT == 1 marks an observed event; SMR is
# the time variable.
c = {}
for i in range(3):
    for j in range(3):
        key = str(i + 1) + str(j + 1)
        c[key] = data[(data.メンテナンスCID == i + 1) & (data.使われ方CID == j + 1)]
        # NOTE(review): the original indentation was lost in flattening —
        # the plotting statements below are assumed to sit inside the inner
        # loop (one figure/png per cell, consistent with the per-key
        # savefig filename); confirm against the original script.
        figure = plt.figure()
        # NOTE(review): the lambda parameter `i` shadows the outer loop
        # variable `i`; harmless here but worth renaming.
        time, survival_prob = kaplan_meier_estimator(
            list(map(lambda i: i == 1, c[key]["FSTAT"])), c[key]["SMR"])
        print(key + str(survival_prob))
        plt.step(time, survival_prob, where="post")
        plt.savefig(key + '.png')
""" from sksurv.datasets import load_veterans_lung_cancer data_x, data_y = load_veterans_lung_cancer() data_y import pandas as pd pd.DataFrame.from_records(data_y[[11, 5, 32, 13, 23]], index=range(1, 6)) %matplotlib inline import matplotlib.pyplot as plt from sksurv.nonparametric import kaplan_meier_estimator time, survival_prob = kaplan_meier_estimator(data_y["Status"], data_y["Survival_in_days"]) plt.step(time, survival_prob, where="post") plt.ylabel("est. probability of survival $\hat{S}(t)$") plt.xlabel("time $t$") data_x["Treatment"].value_counts() for treatment_type in ("standard", "test"): mask_treat = data_x["Treatment"] == treatment_type time_treatment, survival_prob_treatment = kaplan_meier_estimator( data_y["Status"][mask_treat], data_y["Survival_in_days"][mask_treat]) plt.step(time_treatment, survival_prob_treatment, where="post", label="Treatment = %s" % treatment_type)
# # #### scikit-survival # %% # #!conda install -c sebp scikit-survival import matplotlib.pyplot as plt import matplotlib as mpl plt.style.use('seaborn-notebook') mpl.rc('font', **{'family': 'serif', 'serif': 'Palatino'}) from sksurv.nonparametric import kaplan_meier_estimator os_event = dat['OS event'].map(lambda x: bool(x)) for status in dat['ER Status'].unique(): mask = dat['ER Status']== status time_status, surv_prob_status = kaplan_meier_estimator(os_event[mask], dat['OS Time'][mask]) plt.step(time_status, surv_prob_status, where='post', label = "%s (n = %d)" % (status, mask.sum())) plt.ylabel('Estimated probability of survival $\hat{S}(t)$') plt.xlabel('Time $t$') plt.legend(loc='best'); # %% [markdown] # #### lifelines # %% from lifelines import KaplanMeierFitter from lifelines.plotting import add_at_risk_counts