def test_survival_squared_hinge_loss(self):
    """Check that the naive and fast survival SVMs agree for squared hinge loss.

    Both estimators are fitted on the data returned by
    ``self.get_data_without_ties()`` with identical tolerance, iteration
    limit and random seed. The test then verifies that:

    * the learned coefficients match to 3 decimal places,
    * both models return the same number of predictions,
    * the concordance index results computed from either model's
      predictions are equal (first entry approximately, remaining
      entries exactly).
    """
    features, target = self.get_data_without_ties()

    # Solver settings shared by both estimators so the fits are comparable.
    solver_settings = dict(tol=8e-7, max_iter=1000, random_state=0)

    naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False, **solver_settings)
    naive_model.fit(features, target)

    fast_model = FastSurvivalSVM(optimizer='avltree', **solver_settings)
    fast_model.fit(features, target)

    # The naive model's coefficients are flattened before the comparison.
    assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

    naive_pred = naive_model.predict(features)
    fast_pred = fast_model.predict(features)
    self.assertEqual(len(naive_pred), len(fast_pred))

    event_field = 'fstat'
    time_field = 'lenfol'
    naive_cindex = concordance_index_censored(
        target[event_field], target[time_field], naive_pred
    )
    fast_cindex = concordance_index_censored(
        target[event_field], target[time_field], fast_pred
    )

    # First entry is compared approximately; the rest must match exactly.
    self.assertAlmostEqual(naive_cindex[0], fast_cindex[0])
    self.assertTupleEqual(naive_cindex[1:], fast_cindex[1:])
def test_survival_squared_hinge_loss(whas500_without_ties):
    """Check that the naive and fast survival SVMs agree for squared hinge loss.

    ``whas500_without_ties`` supplies an ``(x, y)`` pair. Both estimators
    are fitted on it with identical tolerance, iteration limit and random
    seed. The test then verifies that:

    * the learned coefficients match to 3 decimal places,
    * both models return the same number of predictions,
    * the fast model's predictions satisfy ``assert_cindex_almost_equal``
      against the concordance index result of the naive model.
    """
    features, target = whas500_without_ties

    # Solver settings shared by both estimators so the fits are comparable.
    solver_settings = dict(tol=8e-7, max_iter=1000, random_state=0)

    naive_model = NaiveSurvivalSVM(loss='squared_hinge', dual=False, **solver_settings)
    naive_model.fit(features, target)

    fast_model = FastSurvivalSVM(optimizer='avltree', **solver_settings)
    fast_model.fit(features, target)

    # The naive model's coefficients are flattened before the comparison.
    assert_array_almost_equal(naive_model.coef_.ravel(), fast_model.coef_, 3)

    naive_pred = naive_model.predict(features)
    fast_pred = fast_model.predict(features)
    assert len(naive_pred) == len(fast_pred)

    event_field = 'fstat'
    time_field = 'lenfol'

    # The naive model's concordance result is the reference for the fast model.
    reference_cindex = concordance_index_censored(
        target[event_field], target[time_field], naive_pred
    )
    assert_cindex_almost_equal(
        target[event_field], target[time_field], fast_pred, reference_cindex
    )