def test_energy_discrete():
    """Compare the energy of a discrete labeling computed in two ways.

    For every installed inference method among "qpbo" and "ad3", ten random
    inputs and weight vectors are drawn.  The labeling returned by
    non-relaxed inference is scored once with ``compute_energy`` on the
    model's unary/pairwise potentials and once as the dot product of the
    joint feature vector with the weights; both values must agree.
    """
    for method in get_installed(["qpbo", "ad3"]):
        model = EdgeFeatureGraphCRF(n_states=3, inference_method=method,
                                    n_edge_features=2, n_features=3)
        for _ in range(10):
            # Random node features laid out on a 7 x 8 grid, 3 per node.
            grid = np.random.normal(size=(7, 8, 3))
            grouped_edges = make_grid_edges(grid, 4, return_lists=True)
            edges = np.vstack(grouped_edges)
            sample = (grid.reshape(-1, 3), edges,
                      edge_list_to_features(grouped_edges))

            # Three 3x3 blocks drawn in order: unary, then two pairwise.
            weight_blocks = [np.random.normal(size=(3, 3)) for _ in range(3)]
            w = np.hstack([block.ravel() for block in weight_blocks])

            labeling = model.inference(sample, w, relaxed=False)

            unary_pot = model._get_unary_potentials(sample, w)
            pairwise_pot = model._get_pairwise_potentials(sample, w)
            energy_from_potentials = compute_energy(unary_pot, pairwise_pot,
                                                    edges, labeling)

            energy_from_features = np.dot(
                model.joint_feature(sample, labeling), w)

            assert_almost_equal(energy_from_potentials, energy_from_features)
def test_energy_discrete():
    """Check that potential-based energy equals ``dot(joint_feature, w)``.

    For each installed inference method among "qpbo" and "ad3", the test
    draws ten random inputs and weight vectors, runs non-relaxed inference,
    and asserts that ``compute_energy`` on the model's potentials matches
    the dot product of the joint feature vector with the weights.

    NOTE(review): a function with this same name is defined earlier in this
    file; this later definition replaces it at import time.
    """
    for inference_method in get_installed(["qpbo", "ad3"]):
        crf = EdgeFeatureGraphCRF(n_states=3,
                                  inference_method=inference_method,
                                  n_edge_features=2, n_features=3)
        # ``range`` instead of ``xrange``: the latter does not exist on
        # Python 3, and ``range`` works on both major versions.
        for i in range(10):
            x = np.random.normal(size=(7, 8, 3))
            edge_list = make_grid_edges(x, 4, return_lists=True)
            edges = np.vstack(edge_list)
            edge_features = edge_list_to_features(edge_list)
            # Model input: (flattened node features, edges, edge features).
            x = (x.reshape(-1, 3), edges, edge_features)

            unary_params = np.random.normal(size=(3, 3))
            pw1 = np.random.normal(size=(3, 3))
            pw2 = np.random.normal(size=(3, 3))
            # Weight vector: unary block followed by the two pairwise blocks.
            w = np.hstack([unary_params.ravel(), pw1.ravel(), pw2.ravel()])
            y_hat = crf.inference(x, w, relaxed=False)
            energy = compute_energy(crf._get_unary_potentials(x, w),
                                    crf._get_pairwise_potentials(x, w), edges,
                                    y_hat)

            joint_feature = crf.joint_feature(x, y_hat)
            energy_svm = np.dot(joint_feature, w)

            assert_almost_equal(energy, energy_svm)
# Example #3
0
def test_unary_potentials():
    """Compare typed-CRF potentials against the single-type reference model.

    A ``NodeTypeEdgeFeatureGraphCRF`` with one node type is built on a
    two-node, one-edge graph, and an ``EdgeFeatureGraphCRF`` is built on the
    same data.  With ``np.arange`` weight vectors sized by each model's
    ``size_joint_feature``, the typed model's unary and pairwise potentials
    must equal the reference potentials wrapped in a one-element list.
    """
    typed_crf = NodeTypeEdgeFeatureGraphCRF(
        1,                # number of node types
        [4],              # number of labels per node type
        [3],              # number of features per node type
        np.array([[3]]),  # number of features per (node type, node type)
    )

    # One entry per node type / edge type; there is a single type here.
    node_f = [np.array([[1, 1, 1],
                        [2, 2, 2]])]
    edges = [np.array([[0, 1]])]  # an edge from node 0 to node 1
    edge_f = [np.array([[3, 3, 3]])]
    x = (node_f, edges, edge_f)
    y = np.hstack([np.array([1, 2])])
    typed_crf.initialize(x, y)

    # Reference model fed the same data without the per-type list wrapping.
    reference_crf = EdgeFeatureGraphCRF(4, 3, 3)
    x_ref = (node_f[0], edges[0], edge_f[0])
    w_ref = np.arange(reference_crf.size_joint_feature)
    expected_unary = reference_crf._get_unary_potentials(x_ref, w_ref)

    w = np.arange(typed_crf.size_joint_feature)
    actual_unary = typed_crf._get_unary_potentials(x, w)
    assert_array_equal(actual_unary, [expected_unary])

    expected_pairwise = reference_crf._get_pairwise_potentials(x_ref, w_ref)
    actual_pairwise = typed_crf._get_pairwise_potentials(x, w)
    assert_array_equal(actual_pairwise, [expected_pairwise])