-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
75 lines (60 loc) · 2.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Module authorship metadata.
__author__ = "luoshalin"
import sys
from preprocess import preprocess
from preprocess import analyze
from forwardBackword import gen_matrix
from forwardBackword import forward
from forwardBackword import backward
from forwardBackword import forward_backward
from forwardBackword import plot_pcll
from forwardBackword import vtb_decode
from forwardBackword import ll_decode
def _train_hmm(input_str, A, B, index_dic, threshold):
    """Re-estimate A and B with forward-backward until the PCLL converges.

    Each iteration runs ``forward`` and ``backward`` on the current
    parameters, re-estimates A and B with ``forward_backward``, and records
    the PCLL value reported by ``forward`` (presumably a per-character
    log-likelihood -- confirm in forwardBackword). Iteration stops once two
    consecutive PCLL values differ by less than ``threshold``.

    :param input_str: preprocessed training sequence from ``preprocess``.
    :param A: initial A matrix from ``gen_matrix``.
    :param B: initial B matrix from ``gen_matrix``.
    :param index_dic: index mapping from ``gen_matrix``.
    :param threshold: convergence tolerance on the change in PCLL.
    :returns: ``(A, B, pcll_list)``. These are the parameters after the last
        re-estimation step and the PCLL recorded at every iteration, in order.
    """
    pcll_list = []
    # No previous value yet. Using None (not a magic number such as -1000.0)
    # guarantees the first iteration can never be mistaken for convergence.
    pcll_old = None
    itr = 1
    while True:
        alpha_table, pcll_new = forward(input_str, A, B, index_dic)
        # backward's PCLL should match forward's; only its table is needed here.
        beta_table, _ = backward(input_str, A, B, index_dic)
        A, B = forward_backward(alpha_table, beta_table, A, B, index_dic, input_str)
        pcll_list.append(pcll_new)
        # Print every iteration, including the one that triggers convergence.
        print('ITERATION#' + str(itr) + ': ' + str(pcll_new))
        if pcll_old is not None and abs(pcll_new - pcll_old) < threshold:
            return A, B, pcll_list
        pcll_old = pcll_new
        itr += 1


def main(argv):
    """Train an HMM with forward-backward re-estimation, then decode a corpus.

    Pipeline:
      1. Preprocess the HMM data file and print ``analyze`` output for it.
      2. Build the initial parameters with ``gen_matrix`` (A and B are
         presumably the transition and emission matrices -- confirm in
         forwardBackword). Re-estimate them until the PCLL changes by less
         than ``threshold`` between iterations.
      3. Plot the PCLL curve and print the final A and B.
      4. Preprocess the decoding data file, run ``vtb_decode`` and
         ``ll_decode`` on it, and print the ``ll_decode`` result.

    :param argv: command-line arguments without the program name.
        ``argv[0]`` optionally overrides the HMM data path.
        ``argv[1]`` optionally overrides the decoding data path.
        Missing entries fall back to the original hard-coded paths, so
        ``main([])`` behaves as before. Other HMM data files used during
        development include ``../../data/hmm_train_data``,
        ``../../data/hmm_test_data`` and ``../../data/hmm_train_data_jpn``.
    :returns: ``(A, B, vtb_hidden_state_list, ll_hidden_state_list)``, so the
        run can be reused programmatically. The script entry point ignores it.
    """
    args = list(argv or [])
    hmm_train_file_path = args[0] if args else '../../data/hmm_test_data_jpn'
    vtb_train_file_path = args[1] if len(args) > 1 else '../../data/viterbi_train_data'
    # F-B stopping criterion on the change in PCLL between iterations.
    threshold = 1e-5

    # Preprocess and analyze the training data.
    input_str = preprocess(hmm_train_file_path)
    analyze(input_str)

    # Forward-backward re-estimation until convergence.
    index_dic, A, B = gen_matrix()
    A, B, pcll_list = _train_hmm(input_str, A, B, index_dic, threshold)

    plot_pcll(pcll_list)
    print('\n=========FINAL A & B===========')
    print('A:')
    print(A)
    print('\nB:')
    print(B)
    print('==============================\n')

    # Decode the second data file with both decoders.
    vtb_input_str = preprocess(vtb_train_file_path)
    # NOTE(review): the Viterbi result is computed but, as in the original,
    # not printed -- confirm whether vtb_decode reports it itself.
    vtb_hidden_state_list = vtb_decode(vtb_input_str, A, B, index_dic)
    ll_hidden_state_list = ll_decode(vtb_input_str, B, index_dic)
    print(ll_hidden_state_list)
    print('END!')
    return A, B, vtb_hidden_state_list, ll_hidden_state_list
if __name__ == '__main__':
    # Drop sys.argv[0] (the script path) and hand only user arguments to main.
    cli_args = sys.argv[1:]
    main(cli_args)