def test_graph_crf_loss_augment():
    """Check the energy returned by GraphCRF loss-augmented inference.

    Builds a two-state ``GraphCRF`` using the ``'lp'`` inference method,
    runs ``loss_augmented_inference`` on the module-level sample
    ``(x_1, g_1)`` with target ``y_1`` and weights ``w``, and asserts that

        dot(w, psi(x, y_hat)) + loss(y, y_hat) == -energy

    holds up to ``assert_almost_equal`` precision.
    """
    # NOTE(review): x_1, g_1, y_1 and w are fixtures defined elsewhere in
    # this module; presumably node features, graph edges, labels and the
    # weight vector -- confirm against their definitions.
    sample = (x_1, g_1)
    target = y_1

    model = GraphCRF(n_states=2, inference_method='lp')
    prediction, reported_energy = model.loss_augmented_inference(
        sample, target, w, return_energy=True)

    # The prediction must fulfill the energy + loss condition: the model
    # score of the prediction plus its loss equals the negated energy.
    model_score = np.dot(w, model.psi(sample, prediction))
    augmented_score = model_score + model.loss(target, prediction)
    assert_almost_equal(augmented_score, -reported_energy)