def decode_step(self, trg_words, train):
    """One step decoding: embed target words, advance the decoder LSTM,
    and return the output-layer scores."""
    embed = F.dropout(
        F.pick(self.trg_lookup, trg_words, 1), self.dropout_rate, train)
    hidden = F.dropout(
        self.trg_lstm.forward(embed), self.dropout_rate, train)
    return self.why @ hidden + self.by
def decode_step(self, trg_words, train):
    """One step decoding with dot-product attention and input feeding."""
    embed = F.dropout(
        F.pick(self.trg_lookup, trg_words, 1), self.dropout_rate, train)
    # The previous attention output (self.feed) is fed back into the LSTM
    # together with the current word embedding.
    hidden = self.trg_lstm.forward(F.concat([embed, self.feed], 0))
    hidden = F.dropout(hidden, self.dropout_rate, train)
    # Attention weights over source positions, then the context vector.
    weights = F.softmax(self.t_concat_fb @ hidden, 0)
    context = self.concat_fb @ weights
    # Updated feed vector doubles as the pre-output representation.
    self.feed = F.tanh(self.whj @ F.concat([hidden, context], 0) + self.bj)
    return self.wjy @ self.feed + self.by
def forward(self, inputs):
    """Runs the simple recurrent LM over the sequence.

    Returns one output-score expression per predicted step
    (len(inputs) - 1 entries).
    """
    batch_size = len(inputs[0])
    wlookup = F.parameter(self.pwlookup)
    wxs = F.parameter(self.pwxs)
    wsy = F.parameter(self.pwsy)
    # Hidden state starts at zero for the whole minibatch.
    state = F.zeros(Shape([NUM_HIDDEN_UNITS], batch_size))
    outputs = []
    for words in inputs[:-1]:
        # Embedding plus previous state drives the next state.
        state = F.sigmoid(wxs @ (F.pick(wlookup, words, 1) + state))
        outputs.append(wsy @ state)
    return outputs
def encode(self, src_batch, train):
    """Encodes source sentences and prepares internal states."""
    # NOTE(review): labeled "reversed encoding" in the original — the loop
    # itself is in order, so presumably the caller supplies src_batch
    # already reversed; confirm against the caller.
    src_lookup = F.parameter(self.psrc_lookup)
    self.src_lstm.restart()
    for words in src_batch:
        embed = F.dropout(
            F.pick(src_lookup, words, 1), self.dropout_rate, train)
        self.src_lstm.forward(embed)
    # Initializes decoder states.
    self.trg_lookup = F.parameter(self.ptrg_lookup)
    self.why = F.parameter(self.pwhy)
    self.by = F.parameter(self.pby)
    # The decoder LSTM starts from the encoder's final cell/hidden states.
    self.trg_lstm.restart(self.src_lstm.get_c(), self.src_lstm.get_h())
def decode_step(self, trg_words, train):
    """One step decoding with additive (MLP) attention."""
    src_len = self.concat_fb.shape()[1]
    # Project the previous decoder state and broadcast it across every
    # source position to score them jointly.
    query = self.whw_ @ self.trg_lstm_.get_h()
    query = F.reshape(query, Shape([1, query.shape()[0]]))
    query = F.broadcast(query, 0, src_len)
    energies = F.tanh(self.t_concat_fb @ self.wfbw_ + query)
    weights = F.softmax(energies @ self.wwe_, 0)
    context = self.concat_fb @ weights
    embed = F.dropout(
        F.pick(self.trg_lookup_, trg_words, 1), self.dropout_rate_, train)
    # Embedding and attention context jointly drive the decoder LSTM.
    hidden = self.trg_lstm_.forward(F.concat([embed, context], 0))
    hidden = F.dropout(hidden, self.dropout_rate_, train)
    joint = F.tanh(self.whj_ @ hidden + self.bj_)
    return self.wjy_ @ joint + self.by_
def forward(self, inputs, train):
    """Runs the two-layer recurrent LM step by step.

    Args:
        inputs: Sequence of timesteps; each element is a batch of word IDs.
        train: Whether dropout is active.

    Returns:
        List of output-score expressions, one per predicted timestep
        (len(inputs) - 1 entries).
    """
    # Fix: dropped the unused local `batch_size = len(inputs[0])`.
    lookup = F.parameter(self.plookup)
    self.rnn1.restart()
    self.rnn2.restart()
    self.hy.reset()
    outputs = []
    # The last step is never fed in: targets are the inputs shifted by one.
    for words in inputs[:-1]:
        x = F.dropout(F.pick(lookup, words, 1), DROPOUT_RATE, train)
        h1 = F.dropout(self.rnn1.forward(x), DROPOUT_RATE, train)
        h2 = F.dropout(self.rnn2.forward(h1), DROPOUT_RATE, train)
        outputs.append(self.hy.forward(h2))
    return outputs
def encode(self, src_batch, train):
    """Encodes source sentences and prepares internal states."""
    # Embedding lookup with dropout.
    src_lookup = F.parameter(self.psrc_lookup)
    embeds = []
    for words in src_batch:
        embeds.append(F.dropout(
            F.pick(src_lookup, words, 1), self.dropout_rate, train))
    # Forward pass over the source.
    self.src_fw_lstm.restart()
    fwd = []
    for e in embeds:
        fwd.append(F.dropout(
            self.src_fw_lstm.forward(e), self.dropout_rate, train))
    # Backward pass over the source (restored to source order afterwards).
    self.src_bw_lstm.restart()
    bwd = []
    for e in reversed(embeds):
        bwd.append(F.dropout(
            self.src_bw_lstm.forward(e), self.dropout_rate, train))
    bwd.reverse()
    # Combines the two directions per position — element-wise sum,
    # despite the "concat_fb" attribute name.
    fb_list = [f + b for f, b in zip(fwd, bwd)]
    self.concat_fb = F.concat(fb_list, 1)
    self.t_concat_fb = F.transpose(self.concat_fb)
    # Initializes decoder states.
    embed_size = self.psrc_lookup.shape()[0]
    self.trg_lookup = F.parameter(self.ptrg_lookup)
    self.whj = F.parameter(self.pwhj)
    self.bj = F.parameter(self.pbj)
    self.wjy = F.parameter(self.pwjy)
    self.by = F.parameter(self.pby)
    # Input-feeding vector starts at zero.
    self.feed = F.zeros([embed_size])
    # Decoder starts from the summed final states of both directions.
    self.trg_lstm.restart(
        self.src_fw_lstm.get_c() + self.src_bw_lstm.get_c(),
        self.src_fw_lstm.get_h() + self.src_bw_lstm.get_h())
def encode(self, src_batch, train):
    """Encodes source sentences with a bidirectional LSTM and loads the
    attention/decoder parameters."""
    # Embedding lookup with dropout.
    src_lookup = F.parameter(self.psrc_lookup_)
    embeds = [
        F.dropout(F.pick(src_lookup, words, 1), self.dropout_rate_, train)
        for words in src_batch
    ]
    # Forward encoding.
    self.src_fw_lstm_.reset()
    fwd = [
        F.dropout(self.src_fw_lstm_.forward(e), self.dropout_rate_, train)
        for e in embeds
    ]
    # Backward encoding; restored to source order afterwards.
    self.src_bw_lstm_.reset()
    bwd = [
        F.dropout(self.src_bw_lstm_.forward(e), self.dropout_rate_, train)
        for e in reversed(embeds)
    ]
    bwd.reverse()
    # Concatenates per-position forward/backward states.
    fb_list = [F.concat([f, b], 0) for f, b in zip(fwd, bwd)]
    self.concat_fb = F.concat(fb_list, 1)
    self.t_concat_fb = F.transpose(self.concat_fb)
    # Initializes decode states.
    self.wfbw_ = F.parameter(self.pwfbw_)
    self.whw_ = F.parameter(self.pwhw_)
    self.wwe_ = F.parameter(self.pwwe_)
    self.trg_lookup_ = F.parameter(self.ptrg_lookup_)
    self.whj_ = F.parameter(self.pwhj_)
    self.bj_ = F.parameter(self.pbj_)
    self.wjy_ = F.parameter(self.pwjy_)
    self.by_ = F.parameter(self.pby_)
    self.trg_lstm_.reset()
def forward(self, inputs, train):
    """Runs the two-layer recurrent LM over the whole sequence at once.

    Args:
        inputs: Sequence of timesteps; each element is a batch of word IDs.
        train: Whether dropout is active.

    Returns:
        List of output-score expressions, one per predicted timestep
        (len(inputs) - 1 entries).
    """
    # Fix: dropped the unused local `batch_size = len(inputs[0])`.
    lookup = F.parameter(self.plookup)
    self.rnn1.restart()
    self.rnn2.restart()
    self.hy.reset()
    # Embed every step but the last; targets are the inputs shifted by one.
    xs = [
        F.dropout(F.pick(lookup, words, 1), DROPOUT_RATE, train)
        for words in inputs[:-1]
    ]
    hs1 = self.rnn1.forward(xs)
    hs1 = [F.dropout(h, DROPOUT_RATE, train) for h in hs1]
    hs2 = self.rnn2.forward(hs1)
    return [
        self.hy.forward(F.dropout(h, DROPOUT_RATE, train)) for h in hs2
    ]