def test_n_slack_svm_as_crf_pickling():
    """Fit an NSlackSSVM on iris and check the logger-saved model scores well.

    Each iris sample is wrapped as a one-node graph with an empty edge
    list so it can be fed to ``GraphCRF``. Both the fitted learner and the
    object returned by ``logger.load()`` must score above 0.97 on the
    held-out split.
    """
    import os  # local import: only needed for temp-file housekeeping

    iris = load_iris()
    X, y = iris.data, iris.target

    # (features, edges) pairs; the edge array has zero rows, i.e. no edges.
    # Builtin ``int`` replaces the removed ``np.int`` alias (same dtype).
    X_ = [(np.atleast_2d(x), np.empty((0, 2), dtype=int)) for x in X]
    Y = y.reshape(-1, 1)

    X_train, X_test, y_train, y_test = train_test_split(X_, Y, random_state=1)

    # mkstemp returns an already-open descriptor; close it right away so it
    # does not leak, the logger only needs the path.
    fd, file_name = mkstemp()
    os.close(fd)
    try:
        pbl = GraphCRF(n_features=4, n_states=3, inference_method='lp')
        logger = SaveLogger(file_name)
        svm = NSlackSSVM(pbl, C=100, n_jobs=1, logger=logger)
        svm.fit(X_train, y_train)

        # assert_less(a, b) checks a < b, i.e. the score must exceed 0.97.
        assert_less(.97, svm.score(X_test, y_test))
        assert_less(.97, logger.load().score(X_test, y_test))
    finally:
        # Do not leave the temporary file behind, even if the test fails.
        if os.path.exists(file_name):
            os.remove(file_name)
def test_svm_as_crf_pickling_batch():
    """Fit a FrankWolfeSSVM on iris and check the logger-saved model scores well.

    Each iris sample is wrapped as a one-node graph with an empty edge
    list so it can be fed to ``GraphCRF``. Both the fitted learner and the
    object returned by ``logger.load()`` must score above 0.97 on the
    held-out split.
    """
    import os  # local import: only needed for temp-file housekeeping

    iris = load_iris()
    X, y = iris.data, iris.target

    # (features, edges) pairs; the edge array has zero rows, i.e. no edges.
    # Builtin ``int`` replaces the removed ``np.int`` alias (same dtype).
    X_ = [(np.atleast_2d(x), np.empty((0, 2), dtype=int)) for x in X]
    Y = y.reshape(-1, 1)

    X_train, X_test, y_train, y_test = train_test_split(X_, Y, random_state=1)

    # mkstemp returns an already-open descriptor; close it right away so it
    # does not leak, the logger only needs the path.
    fd, file_name = mkstemp()
    os.close(fd)
    try:
        pbl = GraphCRF(n_features=4, n_states=3, inference_method='unary')
        logger = SaveLogger(file_name)
        # NOTE(review): the test name says "batch" but batch_mode=False is
        # passed here; kept as-is, confirm which mode is intended.
        svm = FrankWolfeSSVM(pbl, C=10, logger=logger, max_iter=50,
                             batch_mode=False)
        svm.fit(X_train, y_train)

        # assert_less(a, b) checks a < b, i.e. the score must exceed 0.97.
        assert_less(.97, svm.score(X_test, y_test))
        assert_less(.97, logger.load().score(X_test, y_test))
    finally:
        # Do not leave the temporary file behind, even if the test fails.
        if os.path.exists(file_name):
            os.remove(file_name)