#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import random

import numpy as np
from chainer import Chain, Variable, cuda, functions, links, optimizer, optimizers, serializers
from chainer import link

import util.generators as gens
from util.functions import trace, fill_batch
from util.vocabulary import Vocabulary
from EncoderDecoder import EncoderDecoder
from Common_function import CommonFunction
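
# NOTE: this module is written against the Chainer 1.x API; later Chainer
# versions deprecate some of these calls (e.g. Link.zerograds in favour of
# cleargrads), so pin an older Chainer release to run it unmodified.
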
class EncoderDecoderModel:

    def __init__(self, parameter_dict):
        self.parameter_dict = parameter_dict
        # Training / test corpora and vocabulary size.
        self.source = parameter_dict["source"]
        self.target = parameter_dict["target"]
        self.test_source = parameter_dict["test_source"]
        self.test_target = parameter_dict["test_target"]
        self.vocab = parameter_dict["vocab"]
        # Network and training hyperparameters.
        self.embed = parameter_dict["embed"]
        self.hidden = parameter_dict["hidden"]
        self.epoch = parameter_dict["epoch"]
        self.minibatch = parameter_dict["minibatch"]
        self.generation_limit = parameter_dict["generation_limit"]
        # Optional pre-trained word2vec model used to initialise embeddings.
        self.word2vec = parameter_dict["word2vec"]
        self.word2vecFlag = parameter_dict["word2vecFlag"]
        self.common_function = CommonFunction()
        self.model = "ChainerDialogue"
        self.encdec = parameter_dict["encdec"]
    def forward(self, src_batch, trg_batch, src_vocab, trg_vocab, encdec,
                is_training, generation_limit):
        """Must return (hyp_batch, loss) when training, else hyp_batch."""
        raise NotImplementedError
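
    # A minimal sketch of the pass forward() is expected to perform, based on
    # how train() and test() call it below. The encode()/decode() method names
    # on EncoderDecoder are assumptions for illustration, not this repository's
    # confirmed API:
    #
    #     loss = 0
    #     for position in reversed(range(src_len)):   # read the source batch
    #         encdec.encode(src_ids_at(position))
    #     hyp_batch = [[] for _ in src_batch]
    #     steps = trg_len if is_training else generation_limit
    #     for t in range(steps):                       # emit target words
    #         y = encdec.decode(prev_ids)
    #         if is_training:
    #             loss += functions.softmax_cross_entropy(y, trg_ids_at(t))
    #         # append the argmax word of y to each hypothesis in hyp_batch
    #     return (hyp_batch, loss) if is_training else hyp_batch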
    def train(self):
        trace('making vocabularies ...')
        src_vocab = Vocabulary.new(gens.word_list(self.source), self.vocab)
        trg_vocab = Vocabulary.new(gens.word_list(self.target), self.vocab)

        trace('making model ...')
        encdec = EncoderDecoder(self.vocab, self.embed, self.hidden)
        if self.word2vecFlag:
            # Initialise encoder and decoder embeddings from word2vec.
            self.copy_model(self.word2vec, encdec.enc)
            self.copy_model(self.word2vec, encdec.dec, dec_flag=True)
        else:
            encdec = self.encdec

        for epoch in range(self.epoch):
            trace('epoch %d/%d: ' % (epoch + 1, self.epoch))
            trained = 0
            gen1 = gens.word_list(self.source)
            gen2 = gens.word_list(self.target)
            gen3 = gens.batch(gens.sorted_parallel(gen1, gen2, 100 * self.minibatch), self.minibatch)
            opt = optimizers.AdaGrad(lr=0.01)
            opt.setup(encdec)
            opt.add_hook(optimizer.GradientClipping(5))
            random_number = random.randint(0, self.minibatch - 1)
            for src_batch, trg_batch in gen3:
                src_batch = fill_batch(src_batch)
                trg_batch = fill_batch(trg_batch)
                K = len(src_batch)
                hyp_batch, loss = self.forward(src_batch, trg_batch, src_vocab, trg_vocab, encdec, True, 0)
                encdec.zerograds()  # clear gradients left over from the previous batch
                loss.backward()
                opt.update()
                if trained == 0:
                    # Show one randomly chosen sample from the first batch of the epoch.
                    self.print_out(random_number, epoch, trained, src_batch, trg_batch, hyp_batch)
                trained += K

            # Checkpoint after every epoch (overwrites the previous checkpoint).
            trace('saving model ...')
            prefix = self.model
            src_vocab.save(prefix + '.srcvocab')
            trg_vocab.save(prefix + '.trgvocab')
            encdec.save_spec(prefix + '.spec')
            serializers.save_hdf5(prefix + '.weights', encdec)

        trace('finished.')
    def test(self):
        trace('loading model ...')
        src_vocab = Vocabulary.load(self.model + '.srcvocab')
        trg_vocab = Vocabulary.load(self.model + '.trgvocab')
        encdec = EncoderDecoder.load_spec(self.model + '.spec')
        serializers.load_hdf5(self.model + '.weights', encdec)

        trace('generating translation ...')
        generated = 0
        # Note: self.target serves as the *output* path here and is overwritten.
        with open(self.target, 'w') as fp:
            for src_batch in gens.batch(gens.word_list(self.source), self.minibatch):
                src_batch = fill_batch(src_batch)
                K = len(src_batch)
                trace('sample %8d - %8d ...' % (generated + 1, generated + K))
                hyp_batch = self.forward(src_batch, None, src_vocab, trg_vocab, encdec, False, self.generation_limit)
                source_count = 0
                for hyp in hyp_batch:
                    # Guarantee a terminator, then cut the hypothesis at the first one.
                    hyp.append('</s>')
                    hyp = hyp[:hyp.index('</s>')]
                    print("src : " + "".join(src_batch[source_count]).replace("</s>", ""))
                    print('hyp : ' + ''.join(hyp))
                    print(' '.join(hyp), file=fp)
                    source_count += 1
                generated += K
        trace('finished.')
    def print_out(self, K, i_epoch, trained, src_batch, trg_batch, hyp_batch):
        """Print the K-th sample of the batch, masking '</s>' padding with '*'."""
        trace('epoch %3d/%3d, sample %8d' % (i_epoch + 1, self.epoch, trained + K + 1))
        trace(' src = ' + ' '.join([x if x != '</s>' else '*' for x in src_batch[K]]))
        trace(' trg = ' + ' '.join([x if x != '</s>' else '*' for x in trg_batch[K]]))
        trace(' hyp = ' + ' '.join([x if x != '</s>' else '*' for x in hyp_batch[K]]))
    def copy_model(self, src, dst, dec_flag=False):
        """Recursively copy parameters from src into dst where names and shapes match."""
        print("start copy")
        for child in src.children():
            if dec_flag:
                # Special case: map the word2vec embedding ("weight_xi") onto
                # the decoder's output projection ("weight_jy").
                if dst["weight_jy"] and child.name == "weight_xi" and self.word2vecFlag:
                    for a, b in zip(child.namedparams(), dst["weight_jy"].namedparams()):
                        b[1].data = a[1].data
                    print('Copy weight_jy')
            if child.name not in dst.__dict__:
                continue
            dst_child = dst[child.name]
            if type(child) != type(dst_child):
                continue
            if isinstance(child, link.Chain):
                self.copy_model(child, dst_child)
            if isinstance(child, link.Link):
                # Copy only when every parameter name and shape matches.
                match = True
                for a, b in zip(child.namedparams(), dst_child.namedparams()):
                    if a[0] != b[0]:
                        match = False
                        break
                    if a[1].data.shape != b[1].data.shape:
                        match = False
                        break
                if not match:
                    print('Ignore %s because of parameter mismatch' % child.name)
                    continue
                for a, b in zip(child.namedparams(), dst_child.namedparams()):
                    b[1].data = a[1].data
                print('Copy %s' % child.name)
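

# A minimal usage sketch, assuming hypothetical corpus paths and hyperparameter
# values; every value below is a placeholder, and in practice a subclass that
# implements forward() would be instantiated instead of this base class.
if __name__ == '__main__':
    parameter_dict = {
        "source": "data/train_source.txt",      # hypothetical path
        "target": "data/train_target.txt",      # hypothetical path
        "test_source": "data/test_source.txt",  # hypothetical path
        "test_target": "data/test_target.txt",  # hypothetical path
        "vocab": 5000,
        "embed": 300,
        "hidden": 400,
        "epoch": 10,
        "minibatch": 64,
        "generation_limit": 256,
        "word2vec": None,         # or a trained word2vec Chain
        "word2vecFlag": False,
        "encdec": EncoderDecoder(5000, 300, 400),
    }
    model = EncoderDecoderModel(parameter_dict)
    model.train()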