-
Notifications
You must be signed in to change notification settings - Fork 0
/
MERT.py
139 lines (120 loc) · 4.91 KB
/
MERT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python
# MERT (Minimum Error Rate Training): tunes log-linear feature weights on
# 100-best translation lists via per-feature line search (Och-style envelope
# sweep), then invokes the reranker with the best weights found.
import optparse
import bleu  # project-local module: bleu_stats() / bleu()
import os
import random
# Command-line options: training k-best list, references, and test k-best list.
optparser = optparse.OptionParser()
optparser.add_option("-t", "--kbest-training-list", dest="training_input", default="workdir/train.with_len.txt", help="100-best training translation lists")
optparser.add_option("-r", "--reference", dest="reference", default="data/train.ref", help="Training target language reference sentences")
optparser.add_option("-k", "--kbest-list", dest="input", default="workdir/dev+test.with_len.txt", help="100-best testing translation lists")
(opts, _) = optparser.parse_args()
def get_reference_sentences():
    """Return the reference corpus as a list of token lists, one per line.

    Fix: use a context manager so the file handle is closed deterministically
    (the original left it open).
    """
    with open(opts.reference) as ref_file:
        return [line.strip().split() for line in ref_file]
def get_candidate_translations():
    """Return the training k-best list as [sent_num, hypothesis, feats] triples.

    Each input line is 'num ||| hypothesis ||| feature string'. Fix: use a
    context manager so the file handle is closed deterministically.
    """
    with open(opts.training_input) as kbest_file:
        return [hyp.split(' ||| ') for hyp in kbest_file]
def update_param(feature, current_param_dict):
    """Line-search the weight of `feature`, holding all other weights fixed.

    For each sentence, builds score lines over the candidate weight, sweeps
    their upper envelope to find the intervals where each hypothesis wins,
    and records BLEU sufficient statistics per interval. Then picks the
    interval whose boundary maximizes corpus BLEU and sets the feature's
    weight to that interval's midpoint.

    Returns (new_param_dict, best_BLEU).
    Relies on module globals: num_sents, reference, all_hyps.
    """
    sentence_dict = {}
    for m in range(num_sents):
        ref = reference[m]
        # Each sentence owns a contiguous 100-hypothesis slice of the k-best list.
        candidates = all_hyps[m * 100:m * 100 + 100]
        line_dict, steepest_line = define_sentence_lines(feature, candidates, current_param_dict)
        # Envelope sweep starts from the steepest line at the -999999 sentinel.
        sequence = find_line_sequence(line_dict, [(steepest_line, -999999)])
        interval_stats_dict = {}
        for candidate, interval_start, interval_end in sequence:
            # candidate is (slope, intercept, hypothesis); stats are for the hypothesis.
            interval_stats_dict[(interval_start, interval_end)] = list(bleu.bleu_stats(candidate[2].split(), ref))
        sentence_dict[m] = interval_stats_dict
    # All distinct interval end points across sentences, in increasing order.
    # (Original shadowed the builtin `dict` inside this comprehension.)
    all_interval_ends = sorted(set(end for intervals in sentence_dict.values() for (_start, end) in intervals))
    best_interval, best_BLEU = choose_best_interval(all_interval_ends, sentence_dict)
    return_param_dict = {}
    for f in current_param_dict:
        # New weight for `feature` is the midpoint of the winning interval.
        return_param_dict[f] = current_param_dict[f] if f != feature else sum(best_interval)/2
    return return_param_dict, best_BLEU
def choose_best_interval(interval_ends, sentence_dict):
    """Pick the weight interval whose corpus BLEU is highest.

    Evaluates BLEU at every interval boundary and returns
    ((interval_start, interval_end), best_BLEU_percent).
    """
    current_best = 0
    best_interval = (0, 0)
    for i in range(len(interval_ends)):
        all_stats = collect_BLEU_stats(sentence_dict, interval_ends[i])
        new_BLEU = 100 * bleu.bleu(all_stats)
        if new_BLEU > current_best:
            current_best = new_BLEU
            # BUG FIX: at i == 0 the original used interval_ends[i-1], which
            # wraps to the LAST (largest) boundary and yields an inverted
            # interval; use the file's -999999 lower sentinel instead.
            lower_bound = interval_ends[i-1] if i > 0 else -999999
            best_interval = (lower_bound, interval_ends[i])
    return best_interval, current_best
def collect_BLEU_stats(iter_dict, interval_end):
    """Sum per-sentence BLEU sufficient statistics for a given boundary.

    For each sentence, selects the (start, end] interval that contains
    interval_end (start < interval_end <= end) and adds its 10-element
    statistics vector into the running total.
    """
    all_stats = [0] * 10  # bleu.bleu_stats produces 10 counts per sentence
    for m in iter_dict:
        intervals = iter_dict[m]
        # Exactly one interval per sentence should satisfy the containment test.
        sentence_stats = [intervals[k] for k in intervals if k[0] < interval_end <= k[1]][0]
        all_stats = [total + s for total, s in zip(all_stats, sentence_stats)]
    return all_stats
def define_sentence_lines(target_feature, candidates, current_param_dict):
    """Build one score line per candidate: score(w) = slope*w + offset.

    The slope is the candidate's value for target_feature; the offset is the
    fixed dot product of all other feature values with their current weights.

    Returns (lines_by_index, steepest_line), where each line is
    (slope, offset, hypothesis) and steepest_line has the smallest slope.
    """
    lines = {}
    steepest = (99999, 0, [])
    for idx, (_num, hypothesis, feature_string) in enumerate(candidates):
        slope = 0.0
        offset = 0.0
        for pair in feature_string.split(' '):
            name, value = pair.split('=')
            if name == target_feature:
                slope = float(value)
            else:
                offset += float(value) * current_param_dict[name]
        lines[idx] = (slope, offset, hypothesis)
        if slope < steepest[0]:
            steepest = (slope, offset, hypothesis)
    return lines, steepest
def find_line_sequence(line_dict, sequence):
    """Trace the upper envelope of score lines, mutating `sequence` in place.

    Starting from the last (line, interval_start) entry, repeatedly finds the
    next line that takes over and rewrites each entry as
    (line, interval_start, interval_end); the final entry ends at the
    +999999 sentinel. Returns the mutated sequence.

    (Iterative form of the original tail recursion; identical trace.)
    """
    while True:
        line, start = sequence[-1][0], sequence[-1][1]
        successor = find_next_line(line, line_dict, start)
        if successor[1] == 999999:
            # No later crossing: close the final interval and stop.
            sequence[-1] = (line, start, 999999)
            return sequence
        sequence[-1] = (line, start, successor[1])
        sequence.append(successor)
def find_next_line(current_line, line_dict, last_intercept):
    """Find the line that crosses current_line soonest after last_intercept.

    Returns (line, crossing_point). When no other line crosses current_line
    beyond last_intercept, returns (current_line, 999999) — the sentinel that
    terminates the envelope sweep.

    Fix: the original iterated with `for id in ...`, shadowing the builtin.
    """
    next_line = current_line
    next_intercept = 999999
    for line_id in line_dict:
        candidate = line_dict[line_id]
        if candidate != current_line:
            crossing = get_intercept(current_line, candidate)
            if last_intercept < crossing < next_intercept:
                next_intercept = crossing
                next_line = candidate
    return next_line, next_intercept
def get_intercept(line1, line2):
    """Return the x-coordinate where two (slope, offset, hyp) lines cross.

    Parallel lines (equal slopes) get the -1000000 sentinel, which the sweep
    treats as "before every interval" and therefore ignores.
    """
    slope_delta = line1[0] - line2[0]
    if not slope_delta:
        return -1000000
    return (line2[1] - line1[1]) / slope_delta
reference = get_reference_sentences()
all_hyps = get_candidate_translations()
num_sents = len(reference)

param_dict = {}
best_BLEU = 0
# Random-restart MERT: 10 restarts from random integer weights. Within each
# restart, greedy coordinate ascent repeatedly applies the single-feature
# line-search update that most improves corpus BLEU, until convergence.
for _ in range(10):
    run_param_dict = {}
    # Feature names are read off the first hypothesis ('name=value' pairs).
    _, _, feats = all_hyps[0]
    for feat in feats.split():
        k, _ = feat.split('=')
        run_param_dict[k] = random.randint(-10, 10)
    # list() keeps random.sample happy on both py2 and py3 dict keys.
    update_order = random.sample(list(run_param_dict), len(run_param_dict))
    previous_BLEU = 0
    current_BLEU = 1
    last_updated = ''
    while current_BLEU - previous_BLEU > 0.0001:
        # BUG FIX: recompute each iteration so the feature updated last round
        # is actually skipped. The original computed this once before the
        # loop, while last_updated was still '', so the filter never applied
        # and last_updated was a dead store.
        to_update = [f for f in update_order if f != last_updated]
        previous_BLEU = current_BLEU
        improvements = [(feat, update_param(feat, run_param_dict)) for feat in to_update]
        # Stable sort keeps the original tie-breaking (last of equal maxima).
        best = sorted(improvements, key=lambda x: x[1][1])[-1]
        current_BLEU = best[1][1]
        run_param_dict = best[1][0]
        last_updated = best[0]
    if current_BLEU > best_BLEU:
        param_dict = run_param_dict
        best_BLEU = current_BLEU

weights = ' '.join(["{}={}".format(feat, param_dict[feat]) for feat in param_dict])
# NOTE(review): shell command built via string formatting. Weights are
# program-generated numbers so injection risk is low here, but
# subprocess.call([...]) with an argument list would be safer.
os.system("python rerank -k {} -w '{}'".format(opts.input, weights))