def test_graph_crf_energy_lp_relaxed():
    """Relaxed inference must report an energy equal to minus the model score.

    Runs relaxed inference with ``return_energy=True`` for ten random weight
    vectors and then once on an all-zero feature input (described as a
    fractional-solution case). Each time it checks that the returned energy
    equals ``-w . joint_feature(x, inf_res)`` for the returned relaxed
    result.
    """
    crf = GraphCRF(n_states=2, n_features=2)
    # Locally seeded generator: the random weights are the same on every run,
    # so a failure can be reproduced. It also leaves global numpy state alone.
    rng = np.random.RandomState(0)
    # `range` rather than the Python-2-only `xrange`; the loop index is unused.
    for _ in range(10):
        w_ = rng.uniform(size=w.shape)
        inf_res, energy_lp = crf.inference((x_1, g_1), w_, relaxed=True,
                                           return_energy=True)
        assert_almost_equal(energy_lp,
                            -np.dot(w_, crf.joint_feature((x_1, g_1), inf_res)))

    # now with fractional solution
    # All-zero features; the original comment states this input yields a
    # fractional (non-integral) relaxed solution.
    x = np.array([[0, 0], [0, 0], [0, 0]])
    inf_res, energy_lp = crf.inference((x, g_1), w, relaxed=True,
                                       return_energy=True)
    assert_almost_equal(energy_lp,
                        -np.dot(w, crf.joint_feature((x, g_1), inf_res)))
def test_graph_crf_energy_lp_integral():
    """LP-relaxed inference on (x_1, g_1) should be integral and energy-consistent.

    With the 'lp' inference method, the first component of the relaxed result
    must have a maximum of 1 along its last axis. Its argmax labeling must
    score to minus the returned energy, checked to 4 decimal places.
    """
    model = GraphCRF(n_states=2, inference_method='lp', n_features=2)
    instance = (x_1, g_1)
    relaxed, energy = model.inference(instance, w, relaxed=True,
                                      return_energy=True)
    # presumably per-node state marginals — TODO confirm against GraphCRF
    node_part = relaxed[0]

    # integrality: every row puts all of its mass on a single state
    assert_array_almost_equal(np.max(node_part, axis=-1), 1)

    # decode the integral solution and compare its score with the energy
    labeling = np.argmax(node_part, axis=-1)
    expected = -np.dot(w, model.joint_feature(instance, labeling))
    assert_almost_equal(energy, expected, 4)
def test_graph_crf_loss_augment():
    """Loss-augmented inference must return an energy matching score plus loss.

    After initializing a default GraphCRF on a single example, the energy
    returned by ``loss_augmented_inference`` must equal minus the sum of two
    terms: the model score ``w . joint_feature(x, y_hat)`` and
    ``loss(y, y_hat)``.
    """
    instance, truth = (x_1, g_1), y_1
    model = GraphCRF()
    model.initialize([instance], [truth])

    y_hat, energy = model.loss_augmented_inference(instance, truth, w,
                                                   return_energy=True)

    # the returned energy should cover both the model score and the loss term
    score = np.dot(w, model.joint_feature(instance, y_hat))
    augmented = score + model.loss(truth, y_hat)
    assert_almost_equal(augmented, -energy)