Example #1
    def test_compare_clinical_kernel(self):
        x_full, y = load_whas500()

        trans = ClinicalKernelTransform()
        trans.fit(x_full)

        kpca = KernelPCA(kernel=trans.pairwise_kernel, copy_X=True)
        xt = kpca.fit_transform(self.x)

        nrsvm = FastSurvivalSVM(optimizer='rbtree',
                                tol=1e-8,
                                max_iter=500,
                                random_state=0)
        nrsvm.fit(xt, y)

        rsvm = FastKernelSurvivalSVM(optimizer='rbtree',
                                     kernel=trans.pairwise_kernel,
                                     tol=1e-8,
                                     max_iter=500,
                                     random_state=0)
        rsvm.fit(self.x.values, y)

        pred_nrsvm = nrsvm.predict(kpca.transform(self.x))
        pred_rsvm = rsvm.predict(self.x.values)

        self.assertEqual(len(pred_nrsvm), len(pred_rsvm))

        c1 = concordance_index_censored(y['fstat'], y['lenfol'], pred_nrsvm)
        c2 = concordance_index_censored(y['fstat'], y['lenfol'], pred_rsvm)

        self.assertAlmostEqual(c1[0], c2[0])
        self.assertTupleEqual(c1[1:], c2[1:])
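The snippets in this gallery are shown without their imports. A plausible header for Example #1, assuming scikit-survival (sksurv) and scikit-learn, would be the following; later examples additionally use standardize, encode_categorical, and categorical_to_numeric, which live in sksurv.column.

# Assumed imports for Example #1 (not part of the original snippet).
from sklearn.decomposition import KernelPCA

from sksurv.datasets import load_whas500
from sksurv.kernels import ClinicalKernelTransform
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastKernelSurvivalSVM, FastSurvivalSVM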
Example #2
    def setUp(self):
        x, self.y = load_whas500()

        x = column.categorical_to_numeric(column.standardize(x,
                                                             with_std=False))
        self.x = x.values
        self.columns = x.columns.tolist()
Example #3
    def test_compare_rbf(self):
        x, y = load_whas500()
        x = encode_categorical(standardize(x))

        kpca = KernelPCA(kernel="rbf")
        xt = kpca.fit_transform(x)

        nrsvm = FastSurvivalSVM(optimizer='rbtree',
                                tol=1e-8,
                                max_iter=1000,
                                random_state=0)
        nrsvm.fit(xt, y)

        rsvm = FastKernelSurvivalSVM(optimizer='rbtree',
                                     kernel="rbf",
                                     tol=1e-8,
                                     max_iter=1000,
                                     random_state=0)
        rsvm.fit(x, y)

        pred_nrsvm = nrsvm.predict(kpca.transform(x))
        pred_rsvm = rsvm.predict(x)

        self.assertEqual(len(pred_nrsvm), len(pred_rsvm))

        c1 = concordance_index_censored(y['fstat'], y['lenfol'], pred_nrsvm)
        c2 = concordance_index_censored(y['fstat'], y['lenfol'], pred_rsvm)

        self.assertAlmostEqual(c1[0], c2[0])
        self.assertTupleEqual(c1[1:], c2[1:])
Example #4
def _make_whas500(with_mean=True, with_std=True, to_numeric=False):
    x, y = load_whas500()
    if with_mean:
        x = standardize(x, with_std=with_std)
    if to_numeric:
        x = categorical_to_numeric(x)
    names = ['(Intercept)'] + x.columns.tolist()
    return DataSetWithNames(x=x.values, y=y, names=names, x_data_frame=x)
Example #5
def test_pandas_inputs(estimator_cls):
    X, y = load_whas500()
    X = X.iloc[:50]
    y = y[:50]
    X = X.loc[:, ["age", "bmi", "chf", "gender"]]

    estimator = estimator_cls()
    if "kernel" in estimator.get_params():
        estimator.set_params(kernel="rbf")
    estimator.fit(X, y)
    estimator.predict(X)
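A hedged usage sketch for the helper above: any scikit-survival estimator class can be passed in; CoxPHSurvivalAnalysis is one such class (the original call site is not shown, so this driver is hypothetical).

# Hypothetical driver for test_pandas_inputs; the real test is presumably
# parametrized over many estimator classes.
from sksurv.linear_model import CoxPHSurvivalAnalysis

test_pandas_inputs(CoxPHSurvivalAnalysis)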
Example #6
def whas500_sparse_data():
    x, y = load_whas500()
    x_dense = categorical_to_numeric(x.select_dtypes(exclude=[numpy.float64]))

    data = []
    index_i = []
    index_j = []
    for j, (_, col) in enumerate(x_dense.items()):
        idx = numpy.flatnonzero(col.values)
        data.extend([1] * len(idx))
        index_i.extend(idx)
        index_j.extend([j] * len(idx))

    x_sparse = coo_matrix((data, (index_i, index_j)))
    return SparseDataSet(x_dense=x_dense, x_sparse=x_sparse, y=y)
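As a sanity check, the COO matrix built above should reproduce the dense one-hot columns exactly, mirroring the assertion in the next example. A minimal sketch, assuming the fixture above is in scope:

from numpy.testing import assert_array_equal

ds = whas500_sparse_data()
assert_array_equal(ds.x_dense.values, ds.x_sparse.toarray())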
Example #7
    def setUp(self):
        x, self.y = load_whas500()
        self.x_dense = column.categorical_to_numeric(x.select_dtypes(exclude=[numpy.float64]))

        data = []
        index_i = []
        index_j = []
        for j, (_, col) in enumerate(self.x_dense.items()):
            idx = numpy.flatnonzero(col.values)
            data.extend([1] * len(idx))
            index_i.extend(idx)
            index_j.extend([j] * len(idx))

        self.x_sparse = coo_matrix((data, (index_i, index_j)))
        assert_array_equal(self.x_dense.values, self.x_sparse.toarray())
Example #8
    def test_fit_and_predict_clinical_kernel(self):
        x_full, y = load_whas500()

        trans = ClinicalKernelTransform()
        trans.fit(x_full)
        x = self.x

        ssvm = FastKernelSurvivalSVM(optimizer="rbtree", kernel=trans.pairwise_kernel,
                                     tol=7e-7, max_iter=100, random_state=0)
        ssvm.fit(x.values, y)

        self.assertFalse(ssvm._pairwise)
        self.assertEqual(x.shape[0], ssvm.coef_.shape[0])

        c = ssvm.score(x.values, y)
        self.assertGreaterEqual(c, 0.854)
Example #9
    def test_fit_and_predict_clinical_kernel(self):
        x_full, y = load_whas500()

        trans = ClinicalKernelTransform()
        trans.fit(x_full)

        x = encode_categorical(standardize(x_full))

        ssvm = FastKernelSurvivalSVM(optimizer="rbtree",
                                     kernel=trans.pairwise_kernel,
                                     max_iter=100,
                                     random_state=0)
        ssvm.fit(x.values, y)

        self.assertFalse(ssvm._pairwise)
        self.assertEqual(x.shape[0], ssvm.coef_.shape[0])

        c = ssvm.score(x.values, y)
        self.assertLessEqual(abs(0.83699051218246412 - c), 1e-3)
Example #10
def whas500_with_ties():
    # The naive survival SVM resolves ties in survival time differently;
    # therefore, use data without ties.
    x, y = load_whas500()
    x = normalize(encode_categorical(x))
    return x, y
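The comment above hinges on whether survival times are tied. A quick sketch to count tied follow-up times in WHAS500 (assuming NumPy):

import numpy
from sksurv.datasets import load_whas500

_, y = load_whas500()
_, counts = numpy.unique(y["lenfol"], return_counts=True)
print((counts > 1).sum())  # follow-up times shared by more than one subject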
Example #11
    def setUp(self):
        x, self.y = load_whas500()
        self.x = encode_categorical(standardize(x))
Example #12
def test_load_whas500():
    x, y = sdata.load_whas500()
    assert x.shape == (500, 14)
    assert y.shape == (500,)
    assert_structured_array_dtype(y, 'fstat', 'lenfol', 215)
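assert_structured_array_dtype is a project-specific test helper not shown here; presumably it checks the field names, the boolean event indicator, and the number of uncensored records (215 of 500). A hypothetical sketch:

import numpy

def assert_structured_array_dtype(arr, event_field, time_field, n_events):
    # Hypothetical reimplementation; the real helper may differ.
    assert arr.dtype.names == (event_field, time_field)
    assert arr[event_field].dtype == numpy.bool_
    assert arr[event_field].sum() == n_events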
Example #13
    def setUp(self):
        x, self.y = load_whas500()
        self.x = standardize(x)
Example #14
def test_nelson_aalen_whas500():
    _, y = load_whas500()
    time = y['lenfol']
    event = y['fstat']

    x, y = nelson_aalen_estimator(event, time)

    true_x = numpy.array(
        [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 16, 17, 18, 19, 20, 22, 26, 31, 32, 33, 34, 37, 42, 46, 49, 52, 53, 55,
         57, 60, 61, 62, 64, 69, 76, 81, 83, 88, 91, 93, 95, 97, 100, 101, 108, 109, 113, 116, 117, 118, 129, 132,
         134, 135, 137, 140, 143, 145, 146, 151, 166, 169, 187, 192, 197, 200, 226, 233, 235, 259, 269, 274, 287,
         289, 295, 297, 312, 313, 321, 328, 343, 345, 354, 358, 359, 363, 368, 371, 373, 376, 382, 385, 386, 390,
         392, 397, 398, 399, 400, 403, 405, 406, 407, 408, 411, 412, 416, 418, 419, 421, 422, 424, 426, 427, 433,
         437, 440, 442, 445, 446, 449, 450, 451, 452, 457, 458, 459, 465, 466, 467, 473, 475, 478, 479, 480, 486,
         497, 506, 507, 510, 511, 516, 519, 521, 522, 523, 524, 529, 530, 532, 535, 537, 542, 544, 550, 551, 552,
         554, 559, 562, 568, 570, 573, 578, 587, 589, 606, 609, 612, 614, 626, 631, 632, 644, 646, 649, 654, 659,
         662, 670, 673, 675, 704, 714, 718, 725, 849, 865, 903, 905, 920, 936, 953, 1048, 1054, 1065, 1096, 1098,
         1102, 1103, 1105, 1106, 1107, 1108, 1109, 1114, 1117, 1121, 1123, 1125, 1126, 1136, 1140, 1150, 1151, 1152,
         1157, 1159, 1160, 1161, 1162, 1163, 1165, 1169, 1170, 1174, 1178, 1182, 1187, 1189, 1190, 1191, 1196, 1199,
         1200, 1203, 1207, 1211, 1217, 1223, 1224, 1231, 1232, 1233, 1234, 1235, 1244, 1245, 1248, 1251, 1253, 1256,
         1257, 1262, 1265, 1266, 1272, 1273, 1274, 1277, 1279, 1280, 1290, 1295, 1298, 1302, 1308, 1314, 1317, 1319,
         1320, 1325, 1329, 1332, 1333, 1336, 1338, 1346, 1347, 1353, 1359, 1363, 1365, 1366, 1374, 1377, 1378, 1381,
         1384, 1385, 1388, 1390, 1400, 1408, 1409, 1420, 1430, 1433, 1438, 1444, 1449, 1451, 1454, 1456, 1458, 1496,
         1506, 1527, 1536, 1548, 1553, 1576, 1577, 1579, 1624, 1627, 1671, 1831, 1836, 1847, 1854, 1858, 1863, 1880,
         1883, 1885, 1887, 1889, 1893, 1899, 1904, 1914, 1919, 1920, 1923, 1926, 1931, 1933, 1934, 1936, 1939, 1940,
         1941, 1942, 1954, 1955, 1964, 1969, 1976, 1977, 1979, 1993, 1994, 2006, 2009, 2025, 2032, 2048, 2057, 2061,
         2064, 2065, 2066, 2083, 2084, 2086, 2100, 2108, 2113, 2114, 2118, 2122, 2123, 2125, 2126, 2131, 2132, 2139,
         2145, 2146, 2151, 2152, 2156, 2160, 2166, 2168, 2172, 2173, 2175, 2178, 2190, 2192, 2350, 2353, 2358])

    assert_array_equal(x.astype(numpy.int_), true_x)

    true_y = numpy.array(
        [0.016, 0.032260162601626, 0.038458509709064, 0.0426165138670682, 0.0467918792115359, 0.0572740595050369,
         0.0699859239118166, 0.0764236921521598, 0.0850630010074515, 0.089420299482397, 0.0916084832898368,
         0.0959944482021175, 0.102602377717536, 0.109254262418201, 0.113718548132487, 0.118202853065222,
         0.120455105317475, 0.12271244166059, 0.127237328538418, 0.1340555103566, 0.13634383987605,
         0.138637417857702, 0.140936268432414, 0.143240415897852, 0.145549884720023, 0.147864699534838,
         0.150184885149687, 0.152510466545036, 0.157172471207041, 0.159514391581748, 0.161861809422123,
         0.164214750598594, 0.16893173173067, 0.173671068223561, 0.176052020604513, 0.178438655449382,
         0.180830999946989, 0.183229081481762, 0.185632927635608, 0.188042566189824, 0.190458025127023,
         0.192879332633076, 0.195306517099095, 0.197739607123426, 0.20017863151367, 0.202623619288731,
         0.205074599680888, 0.20753160213789, 0.209994656325083, 0.212463792127552, 0.214939039652304,
         0.217420429230468, 0.219907991419523, 0.222401757005558, 0.224901757005558, 0.229914288333878,
         0.232433180021536, 0.234958432546788, 0.237490078116409, 0.240028149182399, 0.242572678444485,
         0.247674719260812, 0.252802924389017, 0.255380243976646, 0.257964223304811, 0.26055489687994,
         0.263152299477343, 0.26575646614401, 0.268367432201451, 0.273603034295692, 0.27623461324306,
         0.278873135670501, 0.281518638316003, 0.284171158209903, 0.286830732677988, 0.292164066011321,
         0.294845031158774, 0.297533203201785, 0.300228620991542, 0.302931323694245, 0.305641350794516,
         0.308358742098864, 0.311083537739191, 0.313815778176349, 0.319295230231144, 0.322050051167783,
         0.322050051167783, 0.322050051167783, 0.322050051167783, 0.322050051167783, 0.324866952576234,
         0.327691811333296, 0.327691811333296, 0.327691811333296, 0.330548954190439, 0.333414283703333,
         0.333414283703333, 0.333414283703333, 0.333414283703333, 0.333414283703333, 0.33633826031152,
         0.339270811631168, 0.339270811631168, 0.339270811631168, 0.339270811631168, 0.339270811631168,
         0.339270811631168, 0.339270811631168, 0.342273814634171, 0.342273814634171, 0.345294962670425,
         0.345294962670425, 0.345294962670425, 0.345294962670425, 0.345294962670425, 0.345294962670425,
         0.345294962670425, 0.348390937902623, 0.348390937902623, 0.351545512035115, 0.351545512035115,
         0.351545512035115, 0.351545512035115, 0.351545512035115, 0.351545512035115, 0.351545512035115,
         0.351545512035115, 0.354824200559705, 0.354824200559705, 0.358124530592708, 0.361435788870854,
         0.361435788870854, 0.361435788870854, 0.364780270476205, 0.364780270476205, 0.364780270476205,
         0.368158648854584, 0.368158648854584, 0.368158648854584, 0.368158648854584, 0.368158648854584,
         0.368158648854584, 0.368158648854584, 0.368158648854584, 0.368158648854584, 0.368158648854584,
         0.368158648854584, 0.368158648854584, 0.371717367715794, 0.371717367715794, 0.375314490017952,
         0.378937678423749, 0.382574042060113, 0.382574042060113, 0.382574042060113, 0.382574042060113,
         0.386305385343695, 0.386305385343695, 0.390064783839935, 0.393838368745596, 0.393838368745596,
         0.393838368745596, 0.393838368745596, 0.393838368745596, 0.393838368745596, 0.393838368745596,
         0.393838368745596, 0.393838368745596, 0.397775376619611, 0.401727945789572, 0.401727945789572,
         0.401727945789572, 0.405727945789572, 0.4097440100466, 0.413776268111116, 0.417824851107068,
         0.421889891757474, 0.421889891757474, 0.421889891757474, 0.426005118094923, 0.430137349499881,
         0.430137349499881, 0.434304016166548, 0.438488116584958, 0.442689797257227, 0.442689797257227,
         0.44692708539282, 0.451182404541756, 0.45545590881526, 0.459747754308823, 0.464058099136409,
         0.468387103465413, 0.47273492955237, 0.477101741779444, 0.481487706691725, 0.485892993035337,
         0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399,
         0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399,
         0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399, 0.490317771796399,
         0.495125464104092, 0.495125464104092, 0.495125464104092, 0.495125464104092, 0.500075959153596,
         0.500075959153596, 0.505075959153596, 0.505075959153596, 0.505075959153596, 0.505075959153596,
         0.505075959153596, 0.510204164281801, 0.510204164281801, 0.510204164281801, 0.515412497615135,
         0.515412497615135, 0.515412497615135, 0.515412497615135, 0.515412497615135, 0.515412497615135,
         0.515412497615135, 0.515412497615135, 0.515412497615135, 0.520937359493588, 0.520937359493588,
         0.520937359493588, 0.520937359493588, 0.526587077007712, 0.526587077007712, 0.526587077007712,
         0.526587077007712, 0.532367423828521, 0.538249776769698, 0.538249776769698, 0.538249776769698,
         0.538249776769698, 0.538249776769698, 0.538249776769698, 0.538249776769698, 0.538249776769698,
         0.538249776769698, 0.538249776769698, 0.538249776769698, 0.538249776769698, 0.538249776769698,
         0.538249776769698, 0.538249776769698, 0.538249776769698, 0.538249776769698, 0.544872293325989,
         0.544872293325989, 0.544872293325989, 0.544872293325989, 0.544872293325989, 0.544872293325989,
         0.544872293325989, 0.544872293325989, 0.551914546847116, 0.551914546847116, 0.551914546847116,
         0.551914546847116, 0.551914546847116, 0.551914546847116, 0.551914546847116, 0.551914546847116,
         0.551914546847116, 0.551914546847116, 0.551914546847116, 0.551914546847116, 0.559788562595148,
         0.559788562595148, 0.559788562595148, 0.559788562595148, 0.559788562595148, 0.576182005218098,
         0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098,
         0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098,
         0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098,
         0.576182005218098, 0.576182005218098, 0.576182005218098, 0.576182005218098, 0.586182005218098,
         0.596283015319109, 0.606487096951762, 0.616796375302277, 0.627213041968944, 0.637739357758417,
         0.648377655630758, 0.659130343802801, 0.669999909020192, 0.680988920009203, 0.692100031120314,
         0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494,
         0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494,
         0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494,
         0.703335986176494, 0.703335986176494, 0.703335986176494, 0.703335986176494, 0.718261359310822,
         0.718261359310822, 0.718261359310822, 0.718261359310822, 0.718261359310822, 0.718261359310822,
         0.718261359310822, 0.718261359310822, 0.718261359310822, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945, 0.735805218959945,
         0.735805218959945, 0.735805218959945, 0.819138552293278, 0.819138552293278, 0.819138552293278,
         0.819138552293278, 0.819138552293278, 0.819138552293278, 0.819138552293278, 0.819138552293278,
         0.819138552293278, 1.15247188562661, 1.65247188562661, 2.65247188562661])

    assert_array_almost_equal(y, true_y)
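For reference, the Nelson-Aalen estimator accumulates the cumulative hazard H(t) = sum over event times t_i <= t of d_i / n_i, where d_i is the number of events and n_i the number at risk at t_i; with 500 subjects at risk, the first value 0.016 above corresponds to 8 events at time 1. A naive sketch of that computation (not sksurv's implementation):

import numpy

def nelson_aalen_naive(event, time):
    # H(t) = sum over unique times t_i <= t of d_i / n_i.
    order = numpy.argsort(time, kind="stable")
    time, event = time[order], event[order]
    uniq_times = numpy.unique(time)
    n_at_risk = len(time)
    chf, h = [], 0.0
    for t in uniq_times:
        at_t = time == t
        d = int(event[at_t].sum())    # events observed at t
        h += d / n_at_risk            # Nelson-Aalen increment
        chf.append(h)
        n_at_risk -= int(at_t.sum())  # everyone observed at t leaves the risk set
    return uniq_times, numpy.array(chf)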
Example #15
    def setUp(self):
        x, self.y = load_whas500()
        self.x = categorical_to_numeric(x)
Example #16
def test_whas500():
        _, y = load_whas500()
        time = y['lenfol']
        event = y['fstat']

        x, y = kaplan_meier_estimator(event, time)

        true_x = numpy.array(
            [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 16, 17, 18, 19, 20, 22, 26, 31, 32, 33, 34, 37, 42, 46, 49, 52, 53, 55,
             57, 60, 61, 62, 64, 69, 76, 81, 83, 88, 91, 93, 95, 97, 100, 101, 108, 109, 113, 116, 117, 118, 129, 132,
             134, 135, 137, 140, 143, 145, 146, 151, 166, 169, 187, 192, 197, 200, 226, 233, 235, 259, 269, 274, 287,
             289, 295, 297, 312, 313, 321, 328, 343, 345, 354, 358, 359, 363, 368, 371, 373, 376, 382, 385, 386, 390,
             392, 397, 398, 399, 400, 403, 405, 406, 407, 408, 411, 412, 416, 418, 419, 421, 422, 424, 426, 427, 433,
             437, 440, 442, 445, 446, 449, 450, 451, 452, 457, 458, 459, 465, 466, 467, 473, 475, 478, 479, 480, 486,
             497, 506, 507, 510, 511, 516, 519, 521, 522, 523, 524, 529, 530, 532, 535, 537, 542, 544, 550, 551, 552,
             554, 559, 562, 568, 570, 573, 578, 587, 589, 606, 609, 612, 614, 626, 631, 632, 644, 646, 649, 654, 659,
             662, 670, 673, 675, 704, 714, 718, 725, 849, 865, 903, 905, 920, 936, 953, 1048, 1054, 1065, 1096, 1098,
             1102, 1103, 1105, 1106, 1107, 1108, 1109, 1114, 1117, 1121, 1123, 1125, 1126, 1136, 1140, 1150, 1151, 1152,
             1157, 1159, 1160, 1161, 1162, 1163, 1165, 1169, 1170, 1174, 1178, 1182, 1187, 1189, 1190, 1191, 1196, 1199,
             1200, 1203, 1207, 1211, 1217, 1223, 1224, 1231, 1232, 1233, 1234, 1235, 1244, 1245, 1248, 1251, 1253, 1256,
             1257, 1262, 1265, 1266, 1272, 1273, 1274, 1277, 1279, 1280, 1290, 1295, 1298, 1302, 1308, 1314, 1317, 1319,
             1320, 1325, 1329, 1332, 1333, 1336, 1338, 1346, 1347, 1353, 1359, 1363, 1365, 1366, 1374, 1377, 1378, 1381,
             1384, 1385, 1388, 1390, 1400, 1408, 1409, 1420, 1430, 1433, 1438, 1444, 1449, 1451, 1454, 1456, 1458, 1496,
             1506, 1527, 1536, 1548, 1553, 1576, 1577, 1579, 1624, 1627, 1671, 1831, 1836, 1847, 1854, 1858, 1863, 1880,
             1883, 1885, 1887, 1889, 1893, 1899, 1904, 1914, 1919, 1920, 1923, 1926, 1931, 1933, 1934, 1936, 1939, 1940,
             1941, 1942, 1954, 1955, 1964, 1969, 1976, 1977, 1979, 1993, 1994, 2006, 2009, 2025, 2032, 2048, 2057, 2061,
             2064, 2065, 2066, 2083, 2084, 2086, 2100, 2108, 2113, 2114, 2118, 2122, 2123, 2125, 2126, 2131, 2132, 2139,
             2145, 2146, 2151, 2152, 2156, 2160, 2166, 2168, 2172, 2173, 2175, 2178, 2190, 2192, 2350, 2353, 2358])

        assert_array_equal(x.astype(numpy.int_), true_x)

        true_y = numpy.array(
            [0.984, 0.968, 0.962, 0.958, 0.954, 0.944, 0.932, 0.926, 0.918, 0.914, 0.912, 0.908, 0.902, 0.896, 0.892,
             0.888, 0.886, 0.884, 0.88, 0.874, 0.872, 0.87, 0.868, 0.866, 0.864, 0.862, 0.86, 0.858, 0.854, 0.852, 0.85,
             0.848, 0.844, 0.84, 0.838, 0.836, 0.834, 0.832, 0.83, 0.828, 0.826, 0.824, 0.822, 0.82, 0.818, 0.816,
             0.814, 0.812, 0.81, 0.808, 0.806, 0.804, 0.802, 0.8, 0.798, 0.794, 0.792, 0.79, 0.788, 0.786, 0.784, 0.78,
             0.776, 0.774, 0.772, 0.77, 0.768, 0.766, 0.764, 0.76, 0.758, 0.756, 0.754, 0.752, 0.75, 0.746, 0.744,
             0.742, 0.74, 0.738, 0.736, 0.734, 0.732, 0.73, 0.726, 0.724, 0.724, 0.724, 0.724, 0.724, 0.721960563380282,
             0.719921126760564, 0.719921126760564, 0.719921126760564, 0.717864209255533, 0.715807291750503,
             0.715807291750503, 0.715807291750503, 0.715807291750503, 0.715807291750503, 0.713714287973455,
             0.711621284196407, 0.711621284196407, 0.711621284196407, 0.711621284196407, 0.711621284196407,
             0.711621284196407, 0.711621284196407, 0.709484283342964, 0.709484283342964, 0.70734082629359,
             0.70734082629359, 0.70734082629359, 0.70734082629359, 0.70734082629359, 0.70734082629359, 0.70734082629359,
             0.705150916614662, 0.705150916614662, 0.702926465773606, 0.702926465773606, 0.702926465773606,
             0.702926465773606, 0.702926465773606, 0.702926465773606, 0.702926465773606, 0.702926465773606,
             0.700621788836644, 0.700621788836644, 0.69830950570517, 0.695997222573696, 0.695997222573696,
             0.695997222573696, 0.693669472665422, 0.693669472665422, 0.693669472665422, 0.691325994717228,
             0.691325994717228, 0.691325994717228, 0.691325994717228, 0.691325994717228, 0.691325994717228,
             0.691325994717228, 0.691325994717228, 0.691325994717228, 0.691325994717228, 0.691325994717228,
             0.691325994717228, 0.688865759860583, 0.688865759860583, 0.686387825472596, 0.683900913061463,
             0.68141400065033, 0.68141400065033, 0.68141400065033, 0.68141400065033, 0.678871411095665,
             0.678871411095665, 0.676319262933651, 0.673767114771638, 0.673767114771638, 0.673767114771638,
             0.673767114771638, 0.673767114771638, 0.673767114771638, 0.673767114771638, 0.673767114771638,
             0.673767114771638, 0.671114488335529, 0.66846186189942, 0.66846186189942, 0.66846186189942,
             0.665788014451822, 0.663114167004225, 0.660440319556627, 0.657766472109029, 0.655092624661431,
             0.655092624661431, 0.655092624661431, 0.652396770238956, 0.649700915816481, 0.649700915816481,
             0.646993828667246, 0.644286741518011, 0.641579654368776, 0.641579654368776, 0.638861096511281,
             0.636142538653786, 0.633423980796291, 0.630705422938796, 0.627986865081301, 0.625268307223807,
             0.622549749366312, 0.619831191508817, 0.617112633651322, 0.614394075793827, 0.611675517936333,
             0.611675517936333, 0.611675517936333, 0.611675517936333, 0.611675517936333, 0.611675517936333,
             0.611675517936333, 0.611675517936333, 0.611675517936333, 0.611675517936333, 0.611675517936333,
             0.611675517936333, 0.611675517936333, 0.611675517936333, 0.611675517936333, 0.608734770253946,
             0.608734770253946, 0.608734770253946, 0.608734770253946, 0.605721231787343, 0.605721231787343,
             0.602692625628406, 0.602692625628406, 0.602692625628406, 0.602692625628406, 0.602692625628406,
             0.599601894214927, 0.599601894214927, 0.599601894214927, 0.596478967682557, 0.596478967682557,
             0.596478967682557, 0.596478967682557, 0.596478967682557, 0.596478967682557, 0.596478967682557,
             0.596478967682557, 0.596478967682557, 0.593183503772709, 0.593183503772709, 0.593183503772709,
             0.593183503772709, 0.589832184542355, 0.589832184542355, 0.589832184542355, 0.589832184542355,
             0.586422749949625, 0.582973204361686, 0.582973204361686, 0.582973204361686, 0.582973204361686,
             0.582973204361686, 0.582973204361686, 0.582973204361686, 0.582973204361686, 0.582973204361686,
             0.582973204361686, 0.582973204361686, 0.582973204361686, 0.582973204361686, 0.582973204361686,
             0.582973204361686, 0.582973204361686, 0.582973204361686, 0.579112454663926, 0.579112454663926,
             0.579112454663926, 0.579112454663926, 0.579112454663926, 0.579112454663926, 0.579112454663926,
             0.579112454663926, 0.575034197940941, 0.575034197940941, 0.575034197940941, 0.575034197940941,
             0.575034197940941, 0.575034197940941, 0.575034197940941, 0.575034197940941, 0.575034197940941,
             0.575034197940941, 0.575034197940941, 0.575034197940941, 0.570506369610697, 0.570506369610697,
             0.570506369610697, 0.570506369610697, 0.570506369610697, 0.561153806174456, 0.561153806174456,
             0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456,
             0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456,
             0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456, 0.561153806174456,
             0.561153806174456, 0.561153806174456, 0.561153806174456, 0.555542268112711, 0.549930730050967,
             0.544319191989222, 0.538707653927478, 0.533096115865733, 0.527484577803989, 0.521873039742244,
             0.5162615016805, 0.510649963618755, 0.505038425557011, 0.499426887495266, 0.493815349433521,
             0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521,
             0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521,
             0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521, 0.493815349433521,
             0.493815349433521, 0.493815349433521, 0.493815349433521, 0.486444971083767, 0.486444971083767,
             0.486444971083767, 0.486444971083767, 0.486444971083767, 0.486444971083767, 0.486444971083767,
             0.486444971083767, 0.486444971083767, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052, 0.477910848784052,
             0.477910848784052, 0.438084944718714, 0.438084944718714, 0.438084944718714, 0.438084944718714,
             0.438084944718714, 0.438084944718714, 0.438084944718714, 0.438084944718714, 0.438084944718714,
             0.292056629812476, 0.146028314906238, 0])

        assert_array_almost_equal(y, true_y)
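Analogously, the Kaplan-Meier estimator is the product limit S(t) = product over event times t_i <= t of (1 - d_i / n_i); at the first event time, S = 1 - 8/500 = 0.984, matching true_y[0] above. A naive sketch (not sksurv's implementation):

import numpy

def kaplan_meier_naive(event, time):
    # S(t) = product over unique times t_i <= t of (1 - d_i / n_i).
    order = numpy.argsort(time, kind="stable")
    time, event = time[order], event[order]
    uniq_times = numpy.unique(time)
    n_at_risk = len(time)
    surv, s = [], 1.0
    for t in uniq_times:
        at_t = time == t
        d = int(event[at_t].sum())    # events observed at t
        s *= 1.0 - d / n_at_risk      # product-limit step
        surv.append(s)
        n_at_risk -= int(at_t.sum())  # remove everyone observed at t
    return uniq_times, numpy.array(surv)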
Example #17
print("테스트 세트의 정확도 : {:.2f}".format(Cox.score(X_test, y_test)))

# %%
X = df_LR[[
    # Columns (Korean): normalized population, traffic volume at 07h and at 15h,
    # congestion-frequency and congestion-time intensity sums, car registrations,
    # and electric-vehicle registrations.
    "정규화_인구", "정규화_교통량_07", "정규화_교통량_15", "정규화_혼잡빈도강도합", "정규화_혼잡시간강도합",
    "정규화_자동차등록", "정규화_전기자동차등록"
]]

# %%
X = X.astype(float)
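# Note: sksurv expects y as a structured array of (event, time) fields;
# fitting with a plain int array, as below, will likely raise a ValueError.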
Cox = CoxPHSurvivalAnalysis().fit(X, np.array(list(map(int, y_train))))
# %%
np.array(df_LR[['w_SS']])
# %%
from sksurv.datasets import load_whas500
X, y = load_whas500()
X = X.astype(float)
estimator = CoxPHSurvivalAnalysis().fit(X, y)
chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:10])

for fn in chf_funcs:
    plt.step(fn.x, fn(fn.x), where="post")

plt.ylim(0, 1)
plt.show()
# %%
X
y

df_LR[df_LR['SS_station'] == 1]