예제 #1
0
    def test_survival_squared_hinge_loss(self):
        """Compare NaiveSurvivalSVM with FastSurvivalSVM under squared hinge loss.

        Both estimators are fitted on the data returned by
        ``self.get_data_without_ties()`` using the same tolerance, iteration
        limit and random seed.  The test requires that the coefficients agree
        to 3 decimals, that both models return the same number of predictions,
        and that the concordance index results match.
        """
        features, outcome = self.get_data_without_ties()
        solver_settings = {'tol': 8e-7, 'max_iter': 1000, 'random_state': 0}

        naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False,
                                       **solver_settings)
        naive_model.fit(features, outcome)

        fast_model = FastSurvivalSVM(optimizer='avltree', **solver_settings)
        fast_model.fit(features, outcome)

        # The naive coefficients are flattened before the comparison.
        assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

        naive_pred = naive_model.predict(features)
        fast_pred = fast_model.predict(features)

        self.assertEqual(len(naive_pred), len(fast_pred))

        event = outcome['fstat']
        time = outcome['lenfol']
        naive_cindex = concordance_index_censored(event, time, naive_pred)
        fast_cindex = concordance_index_censored(event, time, fast_pred)

        # The first entry is compared approximately; the remaining entries
        # must be exactly equal.
        self.assertAlmostEqual(naive_cindex[0], fast_cindex[0])
        self.assertTupleEqual(naive_cindex[1:], fast_cindex[1:])
예제 #2
0
    def test_survival_squared_hinge_loss(whas500_without_ties):
        """Compare NaiveSurvivalSVM with FastSurvivalSVM under squared hinge loss.

        ``whas500_without_ties`` is unpacked into features and outcome.  Both
        estimators are fitted with the same tolerance, iteration limit and
        random seed; their coefficients must agree to 3 decimals, they must
        return the same number of predictions, and the fast model's
        concordance index must match the one computed from the naive model.
        """
        features, outcome = whas500_without_ties
        solver_settings = {'tol': 8e-7, 'max_iter': 1000, 'random_state': 0}

        naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False,
                                       **solver_settings)
        naive_model.fit(features, outcome)

        fast_model = FastSurvivalSVM(optimizer='avltree', **solver_settings)
        fast_model.fit(features, outcome)

        # The naive coefficients are flattened before the comparison.
        assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

        naive_pred = naive_model.predict(features)
        fast_pred = fast_model.predict(features)

        assert len(naive_pred) == len(fast_pred)

        event = outcome['fstat']
        time = outcome['lenfol']

        # The naive model's concordance result is the reference the fast
        # model's predictions are checked against.
        reference = concordance_index_censored(event, time, naive_pred)
        assert_cindex_almost_equal(event, time, fast_pred, reference)
예제 #3
0
 def test_fit_uncomparable(whas500_uncomparable):
     """Fitting NaiveSurvivalSVM on the uncomparable data must raise.

     The ``x`` and ``y`` attributes of ``whas500_uncomparable`` are passed to
     ``fit``, which is expected to raise ``NoComparablePairException``.
     """
     dataset = whas500_uncomparable
     model = NaiveSurvivalSVM(
         loss='squared_hinge',
         dual=False,
         tol=1e-8,
         max_iter=1000,
         random_state=0,
     )

     with pytest.raises(NoComparablePairException):
         model.fit(dataset.x, dataset.y)
    def test_fit_with_ties(whas500_with_ties):
        """Fit NaiveSurvivalSVM on ``whas500_with_ties`` and check the result.

        After fitting, the coefficient array must have shape ``(1, 14)`` and
        the score on the training data must equal the stored reference value
        once the absolute difference is rounded to 7 decimals.
        """
        features, outcome = whas500_with_ties
        expected_cindex = 0.7760582309811175

        model = NaiveSurvivalSVM(
            loss='squared_hinge',
            dual=False,
            tol=1e-8,
            max_iter=1000,
            random_state=0,
        )
        model.fit(features, outcome)

        assert model.coef_.shape == (1, 14)

        actual_cindex = model.score(features, outcome)
        deviation = abs(actual_cindex - expected_cindex)
        assert round(deviation, 7) == 0
예제 #5
0
    def test_fit_with_ties(self):
        """Fit NaiveSurvivalSVM on ``self.get_data_with_ties()`` and check the result.

        After fitting, the coefficient array must have shape ``(1, 14)`` and
        the score on the training data must be almost equal to the stored
        reference value.
        """
        features, outcome = self.get_data_with_ties()
        expected_cindex = 0.7760582309811175

        model = NaiveSurvivalSVM(loss='squared_hinge', dual=False,
                                 tol=1e-8, max_iter=1000, random_state=0)
        model.fit(features, outcome)

        self.assertTupleEqual(model.coef_.shape, (1, 14))

        actual_cindex = model.score(features, outcome)
        self.assertAlmostEqual(actual_cindex, expected_cindex)