decoder.py
import torch
import torch.nn as nn
import torch.autograd
from read import Reader
from common import getMaxIndex
from common import is_end_label
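

# The decoder assigns one segmentation label per character and conditions
# each step on the embedding of the most recently completed word, so the
# model can use word-level history on top of the encoder's character
# features.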
class Decoder(nn.Module):
    def __init__(self, hyperParams):
        super(Decoder, self).__init__()
        reader = Reader()
        # Pre-trained word embeddings; optionally fine-tuned.
        self.wordEmb, self.wordDim = reader.load_pretrain(hyperParams.wordEmbFile, hyperParams.wordAlpha, hyperParams.unk)
        self.wordEmb.weight.requires_grad = hyperParams.wordFineTune
        self.dropOut = nn.Dropout(hyperParams.dropProb)
        # lastWords[b] holds the most recently completed word of sentence b.
        self.lastWords = []
        self.hyperParams = hyperParams
        #self.linearLayer = nn.Linear(hyperParams.rnnHiddenSize * 2, hyperParams.labelSize)
        # Score labels from the bidirectional encoder feature concatenated
        # with the last-word embedding.
        self.linearLayer = nn.Linear(hyperParams.rnnHiddenSize * 2 + self.wordDim, hyperParams.labelSize)
        self.softmax = nn.LogSoftmax(dim=1)
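
    # forward() decodes greedily from left to right: at each character
    # position it concatenates the encoder feature with the last-word
    # embedding, scores every label, and feeds the resulting word-boundary
    # decision back in through prepare() for the next step.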
    def forward(self, batch, encoder_output, exams, bTrain=False):
        # encoder_output: (batch, sent_len, rnnHiddenSize * 2)
        sent_len = encoder_output.size()[1]
        self.lastWords = []
        batch_labels = []
        last_word_indexes = torch.autograd.Variable(torch.LongTensor(batch))
        output = []
        for idy in range(batch):
            self.lastWords.append('<s>')
            batch_labels.append([])
            output.append([])
            last_word_indexes.data[idy] = self.hyperParams.wordSTARTID
        for idx in range(sent_len):
            # Encoder feature for character position idx, one row per sentence.
            char_presentation = encoder_output.permute(1, 0, 2)[idx]
            last_word_presentation = self.wordEmb(last_word_indexes)
            last_word_presentation = self.dropOut(last_word_presentation)
            concat = torch.cat((char_presentation, last_word_presentation), 1)
            hidden = self.linearLayer(concat)
            batch_hidden = torch.chunk(hidden, batch, 0)
            for idy in range(batch):
                output[idy].append(batch_hidden[idy])
                labelID = getMaxIndex(self.hyperParams, hidden[idy])
                label = self.hyperParams.labelAlpha.from_id(labelID)
                batch_labels[idy].append(label)
                # Teacher forcing: recover the last word from the gold labels
                # while training, from the predicted labels otherwise.
                if bTrain:
                    self.prepare(exams[idy].m_char, idx, exams[idy].m_label, idy)
                else:
                    self.prepare(exams[idy].m_char, idx, batch_labels[idy], idy)
                wordID = self.hyperParams.wordAlpha.from_string(self.lastWords[idy])
                if wordID < 0:
                    wordID = self.hyperParams.wordUNKID
                last_word_indexes.data[idy] = wordID
        # Stack the per-position scores into (batch * sent_len, labelSize).
        for idy in range(batch):
            output[idy] = torch.cat(output[idy], 0)
        output = torch.cat(output, 0)
        output = self.softmax(output)
        return output
    '''
    linear = self.linearLayer(torch.cat(encoder_output, 0))
    output = self.softmax(linear)
    print(output.size())
    return output
    '''
    def prepare(self, m_char, index, labels, batchIndex):
        # Recover the most recent complete word from the BMES-style labels:
        # 'S' marks a single-character word; 'E' closes a word that is rebuilt
        # by walking back over 'M' labels to the opening 'B'.
        if index < len(m_char):
            if labels[index][0] == 'S' or labels[index][0] == 's':
                self.lastWords[batchIndex] = m_char[index]
            if labels[index][0] == 'E' or labels[index][0] == 'e':
                tmp_word = m_char[index]
                idx = index - 1
                while (idx >= 0) and (labels[idx][0] == 'M' or labels[idx][0] == 'm'):
                    tmp_word += m_char[idx]
                    idx -= 1
                if idx >= 0 and (labels[idx][0] == 'B' or labels[idx][0] == 'b'):
                    tmp_word += m_char[idx]
                # Characters were collected right to left, so reverse.
                self.lastWords[batchIndex] = tmp_word[::-1]
        else:
            # Past the end of the sentence (padding): no word to recover.
            self.lastWords[batchIndex] = '-null-'
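

# A minimal smoke test of the word-recovery logic in prepare(); the input
# below is hypothetical and not part of the training pipeline. Decoder.__new__
# skips __init__, so no embedding file or hyperParams object is needed:
# prepare() only touches self.lastWords.
if __name__ == '__main__':
    dec = Decoder.__new__(Decoder)
    dec.lastWords = ['<s>']
    # Labels 'b m e' mark one three-character word ending at index 2.
    dec.prepare(['中', '国', '人'], 2, ['b', 'm', 'e'], 0)
    print(dec.lastWords[0])  # expected output: 中国人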