示例#1
0
    def test_fit_and_predict_rbf(make_whas500, optimizer):
        """Fit an RBF-kernel FastKernelSurvivalSVM and check its training score.

        The model is trained on the numeric WHAS500 data using the supplied
        ``optimizer``. The test then checks three things:

        * the estimator does not report the ``pairwise`` tag;
        * ``coef_`` has as many rows as there are training samples;
        * ``score`` on the training data is at least 0.965.
        """
        data = make_whas500(to_numeric=True)
        features, target = data.x, data.y

        model = FastKernelSurvivalSVM(
            optimizer=optimizer,
            kernel='rbf',
            tol=2e-6,
            max_iter=75,
            random_state=0,
        )
        model.fit(features, target)

        # A named (non-precomputed) kernel should not be flagged as pairwise.
        tags = model._get_tags()
        assert not tags["pairwise"]
        # The kernel model's coefficients are sized by the training samples.
        n_samples = features.shape[0]
        assert n_samples == model.coef_.shape[0]

        training_score = model.score(features, target)
        assert training_score >= 0.965
示例#2
0
    def test_fit_and_predict_linear_precomputed(make_whas500):
        """Fit FastKernelSurvivalSVM on a precomputed linear Gram matrix.

        The kernel is passed explicitly as ``X @ X.T``. The test checks:

        * the estimator reports the ``pairwise`` tag;
        * ``coef_`` has one row per training sample;
        * ``score`` on a reproducibly shuffled set of the first 250 Gram rows
          matches a reference value to 6 decimal places.
        """
        data = make_whas500(to_numeric=True)
        model = FastKernelSurvivalSVM(optimizer="rbtree",
                                      kernel='precomputed',
                                      random_state=0)
        # Linear kernel supplied directly as the Gram matrix.
        gram = numpy.dot(data.x, data.x.T)
        model.fit(gram, data.y)

        tags = model._get_tags()
        assert tags["pairwise"]
        assert data.x.shape[0] == model.coef_.shape[0]

        # Rows 0..249 in a fixed (seed 0) permutation.
        order = numpy.arange(250)
        rng = numpy.random.RandomState(0)
        rng.shuffle(order)
        subset_score = model.score(gram[order], data.y[order])
        assert round(abs(subset_score - 0.76923445664157997), 6) == 0
示例#3
0
    def test_fit_and_predict_hybrid_rbf(make_whas500):
        """Fit an RBF-kernel FastKernelSurvivalSVM with ``rank_ratio=0.5``.

        The model is fitted with an intercept. The test checks:

        * the estimator is not ``pairwise``;
        * the intercept lies within 0.04 of a reference value;
        * the RMSE between the predictions and ``y['lenfol']`` lies within
          75 of a reference value.
        """
        data = make_whas500(to_numeric=True)
        model = FastKernelSurvivalSVM(
            optimizer="rbtree",
            rank_ratio=0.5,
            kernel="rbf",
            max_iter=50,
            fit_intercept=True,
            random_state=0,
        )
        model.fit(data.x, data.y)

        assert not model._get_tags()["pairwise"]
        reference_intercept = 5.0289145697617164
        assert abs(reference_intercept - model.intercept_) <= 0.04

        # 'lenfol' is presumably the follow-up time field of the structured
        # target; verify against the dataset definition.
        predicted = model.predict(data.x)
        squared_error = mean_squared_error(data.y['lenfol'], predicted)
        rmse = numpy.sqrt(squared_error)
        assert abs(880.20361811281487 - rmse) <= 75
示例#4
0
    def test_fit_and_predict_regression_rbf(make_whas500):
        """Fit an RBF-kernel FastKernelSurvivalSVM with ``rank_ratio=0.0``.

        The model is fitted with an intercept. The test checks:

        * the estimator is not ``pairwise``;
        * the intercept matches a reference value to 7 decimal places;
        * the RMSE against ``y['lenfol']`` matches a reference value to
          6 decimal places.
        """
        data = make_whas500(to_numeric=True)
        model = FastKernelSurvivalSVM(
            optimizer="rbtree",
            rank_ratio=0.0,
            kernel="rbf",
            tol=1e-6,
            max_iter=50,
            fit_intercept=True,
            random_state=0,
        )
        model.fit(data.x, data.y)

        tags = model._get_tags()
        assert not tags["pairwise"]
        intercept_gap = abs(model.intercept_ - 4.9267218894089533)
        assert round(intercept_gap, 7) == 0

        predicted = model.predict(data.x)
        squared_error = mean_squared_error(data.y['lenfol'], predicted)
        rmse = numpy.sqrt(squared_error)
        assert round(abs(rmse - 783.525277), 6) == 0
示例#5
0
    def test_fit_and_predict_clinical_kernel(make_whas500):
        """Fit FastKernelSurvivalSVM using a ClinicalKernelTransform kernel.

        A ``ClinicalKernelTransform`` is fitted on the DataFrame view of the
        data. Its ``pairwise_kernel`` method is then passed to the SVM as the
        kernel. The test checks:

        * the estimator is not ``pairwise``;
        * ``coef_`` has one row per training sample;
        * the training ``score`` is at least 0.854.
        """
        data = make_whas500(to_numeric=True)

        clinical = ClinicalKernelTransform()
        clinical.fit(data.x_data_frame)

        model = FastKernelSurvivalSVM(
            optimizer="rbtree",
            kernel=clinical.pairwise_kernel,
            tol=7e-7,
            max_iter=100,
            random_state=0,
        )
        # The numeric matrix is used for fitting, not the DataFrame.
        features, target = data.x, data.y
        model.fit(features, target)

        assert not model._get_tags()["pairwise"]
        assert features.shape[0] == model.coef_.shape[0]

        training_score = model.score(features, target)
        assert training_score >= 0.854
示例#6
0
    def test_fit_and_predict_linear_regression_precomputed(make_whas500):
        """Fit FastKernelSurvivalSVM (``rank_ratio=0.0``) on a precomputed Gram matrix.

        The kernel is the linear Gram matrix ``X @ X.T`` and the model is
        fitted with an intercept. The test checks:

        * the estimator reports the ``pairwise`` tag;
        * the intercept matches a reference value to 5 decimal places;
        * the RMSE on a shuffled set of the first 250 Gram rows is no larger
          than a reference bound.
        """
        data = make_whas500(to_numeric=True)
        model = FastKernelSurvivalSVM(
            optimizer="rbtree",
            rank_ratio=0.0,
            kernel="precomputed",
            max_iter=50,
            tol=1e-8,
            fit_intercept=True,
            random_state=0,
        )
        gram = numpy.dot(data.x, data.x.T)
        model.fit(gram, data.y)

        assert model._get_tags()["pairwise"]
        intercept_gap = abs(model.intercept_ - 6.416017539824949)
        assert round(intercept_gap, 5) == 0

        # Rows 0..249 of the Gram matrix in a fixed (seed 0) permutation.
        subset = numpy.arange(250)
        numpy.random.RandomState(0).shuffle(subset)
        predicted = model.predict(gram[subset])
        observed = data.y['lenfol'][subset]
        rmse = numpy.sqrt(mean_squared_error(observed, predicted))
        max_rmse = 1342.274550652291 + 0.293
        assert rmse <= max_rmse