Example #1
0
def test_graph_crf_class_weights():
    """Class weights on the loss change what loss-augmented inference picks.

    A single node with no edges is scored with fixed weights. Plain
    inference must pick state 1. Loss-augmented inference against the true
    label ``[1]`` must pick state 2. Once the loss for class 1 is
    down-weighted via ``class_weight``, it must go back to state 1.
    """
    model = GraphCRF(n_states=3, n_features=3)

    # First 9 entries: unary block (identity pattern over 3 states x 3 features).
    # Last 6 entries: pairwise weights, all zero (irrelevant here: no edges).
    w = np.array([1, 0, 0,
                  0, 1, 0,
                  0, 0, 1,
                  0, 0, 0, 0, 0, 0])

    node_features = np.array([[1, 1.5, 1.1]])
    edges = np.empty((0, 2))  # zero edges
    x = (node_features, edges)

    assert_equal(model.inference(x, w), 1)
    # Adding the loss against label 1 makes the last state (2) the winner.
    assert_equal(model.loss_augmented_inference(x, [1], w), 2)

    # Shrinking class 1's loss weight to .1 makes loss-augmented inference
    # settle on state 1 again.
    weighted_model = GraphCRF(n_states=3, n_features=3,
                              class_weight=[1, .1, 1])
    assert_equal(weighted_model.loss_augmented_inference(x, [1], w), 1)
Example #2
0
def test_graph_crf_class_weights():
    """Check class-weighted loss-augmented inference with the "dai" backend.

    Uses a single node without edges. Plain inference picks state 1.
    Loss-augmented inference against label ``[1]`` picks state 2, unless
    class 1's loss is down-weighted, in which case it picks state 1.
    """
    crf = GraphCRF(n_states=3, n_features=3, inference_method="dai")

    unary = [1, 0, 0, 0, 1, 0, 0, 0, 1]  # identity pattern, 3 states x 3 features
    pairwise = [0] * 6  # all zero; unused since there are no edges
    w = np.array(unary + pairwise)

    features = np.array([[1, 1.5, 1.1]])
    edges = np.empty((0, 2))
    x = (features, edges)

    assert_equal(crf.inference(x, w), 1)
    # With the loss term added, the last state (2) wins.
    assert_equal(crf.loss_augmented_inference(x, [1], w), 2)

    # Weighting class 1's loss by 0.1 makes state 1 the loss-augmented answer.
    weighted_crf = GraphCRF(
        n_states=3,
        n_features=3,
        inference_method="dai",
        class_weight=[1, 0.1, 1],
    )
    assert_equal(weighted_crf.loss_augmented_inference(x, [1], w), 1)
Example #3
0
def test_graph_crf_loss_augment():
    """The LP solver's reported energy must match score + loss of its output.

    ``x_1``, ``g_1``, ``y_1`` and ``w`` are fixtures defined outside this
    function.
    """
    x, y = (x_1, g_1), y_1
    crf = GraphCRF(n_states=2, inference_method="lp")

    y_hat, energy = crf.loss_augmented_inference(x, y, w, return_energy=True)

    score = np.dot(w, crf.psi(x, y_hat))
    loss = crf.loss(y, y_hat)
    # The returned energy carries the opposite sign of score + loss.
    assert_almost_equal(score + loss, -energy)
Example #4
0
def test_graph_crf_loss_augment():
    """Loss-augmented inference returns an energy consistent with its labeling.

    The model is initialized from the single example ``(x, y)``. The energy
    returned alongside ``y_hat`` must equal minus
    (``w . joint_feature(x, y_hat)`` + ``loss(y, y_hat)``). ``x_1``, ``g_1``,
    ``y_1`` and ``w`` are fixtures defined outside this function.
    """
    x, y = (x_1, g_1), y_1
    crf = GraphCRF()
    # NOTE(review): presumably derives model sizes from the data — confirm.
    crf.initialize([x], [y])

    y_hat, energy = crf.loss_augmented_inference(x, y, w, return_energy=True)

    objective = np.dot(w, crf.joint_feature(x, y_hat)) + crf.loss(y, y_hat)
    assert_almost_equal(objective, -energy)