-
Notifications
You must be signed in to change notification settings - Fork 0
/
dd_k_best.py
127 lines (108 loc) · 3.13 KB
/
dd_k_best.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# /usr/bin/python
from __future__ import division
import sys, math
import fst_search, viterbi
def init_dd_param(dd_u, n, tagset):
    """Reset ``dd_u`` in place to a zero table of ``n`` positions by ``tagset``.

    For every position ``0 .. n-1``, ``dd_u[position]`` is replaced by a fresh
    dict mapping each tag in ``tagset`` to ``0``. Keys of ``dd_u`` outside that
    range are left untouched. Returns ``None``.
    """
    for position in range(n):
        # fromkeys builds a brand-new dict per position, so rows never alias.
        dd_u[position] = dict.fromkeys(tagset, 0)
def compute_indicators(seq, tagset):
    """Return one-hot tag indicators for every position of ``seq``.

    The result maps each 0-based position ``i`` to a dict holding an entry for
    every tag in ``tagset``: ``1`` if ``seq[i]`` equals that tag, else ``0``.
    """
    return {
        pos: {tag: (1 if label == tag else 0) for tag in tagset}
        for pos, label in enumerate(seq)
    }
def run(sentence, tagset, hmm, k_best_list, iters_per_candidate=200,
        step_scale=21.0):
    """Execute the dual decomposition algorithm used to extend the k-best list.

    One Viterbi subproblem (over ``hmm``) and one ``fst_search`` subproblem per
    entry of ``k_best_list`` are solved under penalties ``u[0] .. u[k]``. The
    penalties are adjusted with ``update`` until every subproblem returns the
    same tag sequence. (What each FST subproblem encodes lives in
    ``fst_search`` and is not visible here.)

    Args:
        sentence: Tokens to tag; ``len(sentence)`` fixes the number of positions.
        tagset: Iterable of candidate tags.
        hmm: Model object passed straight through to ``viterbi.run``.
        k_best_list: Previously found sequences; entry ``j`` drives the
            ``(j+1)``-th subproblem via ``fst_search.run``.
        iters_per_candidate: Iteration budget per entry of ``k_best_list``
            (previously hard-coded as 200).
        step_scale: Numerator of the step size ``step_scale / sqrt(iteration)``
            (previously hard-coded as 21.0).

    Returns:
        ``(seq, iteration)``: the agreed sequence and the 1-based iteration at
        which all subproblems agreed, or ``(last Viterbi sequence, -1)`` if the
        iteration budget ran out.
    """
    n = len(sentence)
    k = len(k_best_list)
    # Always run at least one iteration so seq1 is bound. With an empty
    # k_best_list the Viterbi solution trivially agrees on iteration 1
    # (previously max_iter was 0 and the final return raised UnboundLocalError).
    max_iter = max(1, k * iters_per_candidate)

    u = []  # u[0]: Viterbi penalties; u[j + 1]: penalties for k_best_list[j]
    for _ in range(k + 1):
        u_j = {}
        init_dd_param(u_j, n, tagset)
        u.append(u_j)
    w = {}  # consensus averages, filled in by update()
    init_dd_param(w, n, tagset)
    ku = {}  # negated u[0], handed to Viterbi
    init_dd_param(ku, n, tagset)

    for iteration in range(1, max_iter + 1):
        step_size = step_scale / math.sqrt(iteration)
        seqs = []
        indicators = []

        # Viterbi subproblem sees the negation of u[0].
        for i in u[0]:
            for t in u[0][i]:
                ku[i][t] = -1 * u[0][i][t]
        seq1, _score1, _score2 = viterbi.run(sentence, tagset, hmm, ku)
        seqs.append(seq1)
        indicators.append(compute_indicators(seq1, tagset))

        # One FST subproblem per previously found sequence.
        for j in range(k):
            seq, _fst_score = fst_search.run(k_best_list[j], u[j + 1], tagset)
            seqs.append(seq)
            indicators.append(compute_indicators(seq, tagset))

        # Converged once every subproblem agrees with the Viterbi solution.
        if all(seq == seq1 for seq in seqs[1:]):
            return seq1, iteration
        update(indicators, u, w, step_size)
    return seq1, -1
def update(indicators, u, w, step_size):
    """Take one subgradient step on the dual decomposition parameters.

    First ``w`` is overwritten with the per-position, per-tag mean of the
    subproblem indicators. Then each ``u[j]`` is moved by
    ``-step_size * (indicators[j] - w)``. Finally ``check_dd_param(u)`` is
    called on the result.

    Args:
        indicators: One ``{position: {tag: 0/1}}`` map per subproblem, indexed
            in the same order as ``u``.
        u: Dual parameters, one ``{position: {tag: value}}`` map per
            subproblem; updated in place.
        w: ``{position: {tag: value}}`` table whose keys define the positions
            and tags visited; overwritten in place with the averages.
        step_size: Step length for this iteration.
    """
    count = len(indicators)
    positions = range(len(w))

    # Pass 1: consensus average of the indicators at every (position, tag).
    for pos in positions:
        for tag in w[pos]:
            w[pos][tag] = sum((ind[pos][tag] for ind in indicators), 0.0) / count

    # Pass 2: move each subproblem's parameters against its deviation from
    # the consensus.
    for pos in positions:
        for tag in w[pos]:
            target = w[pos][tag]
            for j, ind in enumerate(indicators):
                u[j][pos][tag] -= step_size * (ind[pos][tag] - target)

    check_dd_param(u)
def check_dd_param(u, tol=0.00000001):
    """Sanity-check that the dual parameters still sum to zero.

    For every position/tag key present in ``u[0]``, the values across all
    ``u[j]`` are summed. If any sum deviates from zero by more than ``tol``
    in either direction, ``DD IS WRONG`` is printed and the check stops.

    Args:
        u: List of ``{position: {tag: value}}`` maps that share ``u[0]``'s keys.
        tol: Absolute tolerance on each sum (the original threshold, 1e-8).

    Returns:
        ``True`` if every sum is within ``tol`` of zero (or ``u`` is empty),
        ``False`` after reporting the first violation.
    """
    if not u:
        return True
    for i, tags in u[0].items():
        for t in tags:
            s = sum(u_j[i][t] for u_j in u)
            # abs(): the invariant is sum == 0, so negative drift is as wrong
            # as positive drift (the original only caught s > tol).
            if abs(s) > tol:
                print("DD IS WRONG")
                return False
    return True