def test_survival_squared_hinge_loss(self):
    """Naive and fast survival SVMs with squared hinge loss must agree.

    Both estimators are fit on the tie-free data set. The test then checks:
    - their coefficients match to 3 decimals;
    - their predictions have the same length;
    - the concordance results computed from each model's predictions agree.
    """
    features, outcome = self.get_data_without_ties()

    naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False,
                                   tol=8e-7, max_iter=1000, random_state=0)
    naive_model.fit(features, outcome)

    fast_model = FastSurvivalSVM(optimizer='avltree', tol=8e-7,
                                 max_iter=1000, random_state=0)
    fast_model.fit(features, outcome)

    # The naive model's coef_ is 2-D (shape (1, n_features)), so flatten it
    # before comparing against the fast model's coefficients.
    assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

    naive_scores = naive_model.predict(features)
    fast_scores = fast_model.predict(features)
    self.assertEqual(len(naive_scores), len(fast_scores))

    fstat = outcome['fstat']
    lenfol = outcome['lenfol']
    naive_result = concordance_index_censored(fstat, lenfol, naive_scores)
    fast_result = concordance_index_censored(fstat, lenfol, fast_scores)

    # Element 0 is a float and is compared approximately; the remaining
    # entries of the returned tuple must match exactly.
    self.assertAlmostEqual(naive_result[0], fast_result[0])
    self.assertTupleEqual(naive_result[1:], fast_result[1:])
def test_survival_squared_hinge_loss(whas500_without_ties):
    """Naive and fast survival SVMs with squared hinge loss must agree.

    Fits both estimators on the ``whas500_without_ties`` fixture. Checks that
    the coefficients match to 3 decimals and the predictions have the same
    length. Also checks that the fast model's concordance result matches the
    one computed from the naive model's predictions.
    """
    features, outcome = whas500_without_ties

    naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False,
                                   tol=8e-7, max_iter=1000, random_state=0)
    naive_model.fit(features, outcome)

    fast_model = FastSurvivalSVM(optimizer='avltree', tol=8e-7,
                                 max_iter=1000, random_state=0)
    fast_model.fit(features, outcome)

    # Flatten the naive model's 2-D coef_ so both vectors line up.
    assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

    naive_scores = naive_model.predict(features)
    fast_scores = fast_model.predict(features)
    assert len(naive_scores) == len(fast_scores)

    fstat = outcome['fstat']
    lenfol = outcome['lenfol']
    # The naive model's concordance result serves as the reference value.
    reference_cindex = concordance_index_censored(fstat, lenfol, naive_scores)
    assert_cindex_almost_equal(fstat, lenfol, fast_scores, reference_cindex)
def test_fit_uncomparable(whas500_uncomparable):
    """Fitting on data without comparable pairs raises NoComparablePairException."""
    estimator = NaiveSurvivalSVM(
        loss='squared_hinge',
        dual=False,
        tol=1e-8,
        max_iter=1000,
        random_state=0,
    )

    with pytest.raises(NoComparablePairException):
        features, outcome = whas500_uncomparable.x, whas500_uncomparable.y
        estimator.fit(features, outcome)
def test_fit_with_ties(whas500_with_ties):
    """NaiveSurvivalSVM fit on data with ties gives the expected c-index.

    Checks that the coefficient array has shape (1, 14). Also checks that the
    training-set score equals the reference value to 7 decimal places.
    """
    x, y = whas500_with_ties

    nrsvm = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=1e-8,
                             max_iter=1000, random_state=0)
    nrsvm.fit(x, y)

    assert nrsvm.coef_.shape == (1, 14)

    cindex = nrsvm.score(x, y)
    # abs=5e-8 matches the old `round(abs(diff), 7) == 0` check (7 decimal
    # places). pytest.approx also reports actual vs. expected on failure.
    assert cindex == pytest.approx(0.7760582309811175, abs=5e-8)
def test_fit_with_ties(self):
    """NaiveSurvivalSVM fit on data with ties gives the expected c-index.

    Checks that the coefficient array has shape (1, 14). Also checks that the
    training-set score matches the reference value (assertAlmostEqual default
    precision).
    """
    features, outcome = self.get_data_with_ties()

    estimator = NaiveSurvivalSVM(
        loss='squared_hinge',
        dual=False,
        tol=1e-8,
        max_iter=1000,
        random_state=0,
    )
    estimator.fit(features, outcome)

    self.assertTupleEqual(estimator.coef_.shape, (1, 14))
    self.assertAlmostEqual(estimator.score(features, outcome), 0.7760582309811175)