def restart(self, init_c = Node(), init_h = Node()):
    """Reset the recurrent state and re-bind the weights as graph nodes.

    Args:
        init_c: Starting value for ``self.c``. Used only when
            ``init_c.valid()`` is true; otherwise zeros are used.
        init_h: Starting value for ``self.h``, with the same fallback rule.
    """
    # State width is read from dimension 1 of the hidden-to-hidden parameter.
    state_size = self.pwhh.shape()[1]

    # Wrap the stored parameters so they can be used in expressions.
    self.wxh = F.parameter(self.pwxh)
    self.whh = F.parameter(self.pwhh)
    self.bh = F.parameter(self.pbh)

    # The default ``Node()`` acts as a "not provided" sentinel: it is only
    # ever queried with ``valid()`` and never modified here.
    if init_c.valid():
        self.c = init_c
    else:
        self.c = F.zeros([state_size])

    if init_h.valid():
        self.h = init_h
    else:
        self.h = F.zeros([state_size])
def forward(self, inputs):
    """Run the recurrent network over a sequence and collect per-step outputs.

    Args:
        inputs: Indexable sequence of steps. ``len(inputs[0])`` is taken as
            the batch size, and each step is passed to ``F.pick`` as the IDs
            to look up. The last step is not consumed.

    Returns:
        A list with ``len(inputs) - 1`` entries, one ``wsy @ s`` node per
        consumed step.
    """
    batch_size = len(inputs[0])

    lookup = F.parameter(self.pwlookup)
    w_in = F.parameter(self.pwxs)
    w_out = F.parameter(self.pwsy)

    # Hidden state starts at zero for every sample in the batch.
    state = F.zeros(Shape([NUM_HIDDEN_UNITS], batch_size))

    outputs = []
    # Every step except the final one is fed in.
    # presumably because the final token has no successor to predict — TODO confirm
    for step_ids in inputs[:-1]:
        embedded = F.pick(lookup, step_ids, 1)
        state = F.sigmoid(w_in @ (embedded + state))
        outputs.append(w_out @ state)
    return outputs
def encode(self, src_batch, train):
    """Encode a source batch and set up the state used for decoding.

    Stores ``concat_fb``, ``t_concat_fb``, the decoder-side parameter nodes
    and ``feed`` on ``self``, and restarts ``trg_lstm`` from the summed
    final states of the two source encoders.

    Args:
        src_batch: Sized sequence of source steps; each element is passed to
            ``F.pick`` as the IDs to look up.
        train: Flag forwarded to every ``F.dropout`` call.
    """
    rate = self.dropout_rate

    # Embedding lookup with dropout, one node per source position.
    src_lookup = F.parameter(self.psrc_lookup)
    embeds = [
        F.dropout(F.pick(src_lookup, ids, 1), rate, train)
        for ids in src_batch
    ]

    # Left-to-right pass.
    self.src_fw_lstm.restart()
    fw_outs = [
        F.dropout(self.src_fw_lstm.forward(emb), rate, train)
        for emb in embeds
    ]

    # Right-to-left pass; flip the results back into source order.
    self.src_bw_lstm.restart()
    bw_outs = [
        F.dropout(self.src_bw_lstm.forward(emb), rate, train)
        for emb in reversed(embeds)
    ]
    bw_outs.reverse()

    # Combine both directions per position with ``+``, then join all
    # positions along dimension 1.
    combined = [fw + bw for fw, bw in zip(fw_outs, bw_outs)]
    self.concat_fb = F.concat(combined, 1)
    self.t_concat_fb = F.transpose(self.concat_fb)

    # Decoder-side setup.
    embed_size = self.psrc_lookup.shape()[0]
    self.trg_lookup = F.parameter(self.ptrg_lookup)
    self.whj = F.parameter(self.pwhj)
    self.bj = F.parameter(self.pbj)
    self.wjy = F.parameter(self.pwjy)
    self.by = F.parameter(self.pby)
    self.feed = F.zeros([embed_size])

    # The decoder starts from the sum of both encoders' final c and h.
    start_c = self.src_fw_lstm.get_c() + self.src_bw_lstm.get_c()
    start_h = self.src_fw_lstm.get_h() + self.src_bw_lstm.get_h()
    self.trg_lstm.restart(start_c, start_h)
def encode(self, src_batch, train):
    """Encode a source batch and prepare the decoder's starting state.

    Stores ``concat_fb_``, ``t_concat_fb_``, the decoder-side parameter
    nodes and ``feed_`` on ``self``, and initializes ``trg_lstm_`` from the
    summed final states of the two source encoders.

    Args:
        src_batch: Sized sequence of source steps; each element is passed to
            ``F.pick`` as the IDs to look up.
        train: Flag forwarded to every ``F.dropout`` call.
    """
    drop = self.dropout_rate_
    fw_lstm = self.src_fw_lstm_
    bw_lstm = self.src_bw_lstm_

    # Embedding lookup with dropout, one node per source position.
    table = F.parameter(self.psrc_lookup_)
    embedded = [
        F.dropout(F.pick(table, ids, 1), drop, train) for ids in src_batch
    ]

    # Forward direction.
    fw_lstm.init()
    fw_states = [
        F.dropout(fw_lstm.forward(vec), drop, train) for vec in embedded
    ]

    # Backward direction: consume the sequence reversed, then restore order.
    bw_lstm.init()
    bw_states = [
        F.dropout(bw_lstm.forward(vec), drop, train)
        for vec in reversed(embedded)
    ]
    bw_states.reverse()

    # Combine both directions per position with ``+``, then join all
    # positions along dimension 1.
    merged = [fw + bw for fw, bw in zip(fw_states, bw_states)]
    self.concat_fb_ = F.concat(merged, 1)
    self.t_concat_fb_ = F.transpose(self.concat_fb_)

    # Decoder-side setup.
    self.trg_lookup_ = F.parameter(self.ptrg_lookup_)
    self.whj_ = F.parameter(self.pwhj_)
    self.bj_ = F.parameter(self.pbj_)
    self.wjy_ = F.parameter(self.pwjy_)
    self.by_ = F.parameter(self.pby_)
    self.feed_ = F.zeros([self.embed_size_])

    # The decoder starts from the sum of both encoders' final c and h.
    first_c = fw_lstm.get_c() + bw_lstm.get_c()
    first_h = fw_lstm.get_h() + bw_lstm.get_h()
    self.trg_lstm_.init(first_c, first_h)
def forward(self, xs):
    """Apply the gated recurrent layer to a whole input sequence.

    All inputs are joined along dimension 1 and multiplied by ``self.w_``
    once; the result is cut into three equal bands along dimension 0
    (candidate, ``f`` gate, ``r`` gate). Only the running state ``c`` is
    then updated step by step.

    Args:
        xs: Sized, indexable sequence of input nodes, one per step.

    Returns:
        A list with one output node per element of ``xs``.
    """
    width = self.out_size_
    steps = len(xs)

    projected = self.w_ @ F.concat(xs, 1)

    candidate = F.slice(projected, 0, 0, width)
    forget = F.sigmoid(
        F.slice(projected, 0, width, 2 * width)
        + F.broadcast(self.bf_, 1, steps))
    # NOTE(review): this gate adds the same bias (`bf_`) as the one above;
    # a separate bias may have been intended — confirm against the class.
    reset = F.sigmoid(
        F.slice(projected, 0, 2 * width, 3 * width)
        + F.broadcast(self.bf_, 1, steps))

    state = F.zeros([width])
    outputs = []
    for pos, x in enumerate(xs):
        # Column `pos` of each band belongs to this step.
        cand_t = F.slice(candidate, 1, pos, pos + 1)
        forget_t = F.slice(forget, 1, pos, pos + 1)
        reset_t = F.slice(reset, 1, pos, pos + 1)

        # Blend previous state and candidate, weighted by the `f` gate.
        state = forget_t * state + (1 - forget_t) * cand_t
        # Mix tanh(state) with the raw input, weighted by the `r` gate.
        outputs.append(reset_t * F.tanh(state) + (1 - reset_t) * x)
    return outputs
def init(self, init_c=Node(), init_h=Node()):
    """Bind the weights as graph nodes and set the starting ``c``/``h`` state.

    Args:
        init_c: Starting value for ``self.c_``. Used only when
            ``init_c.valid()`` is true; otherwise zeros of size
            ``self.out_size_`` are used.
        init_h: Starting value for ``self.h_``, with the same fallback rule.
    """
    self.wxh_ = F.parameter(self.pwxh_)
    self.whh_ = F.parameter(self.pwhh_)
    self.bh_ = F.parameter(self.pbh_)

    # The default ``Node()`` acts as a "not provided" sentinel: it is only
    # ever queried with ``valid()`` and never modified here.
    if init_c.valid():
        self.c_ = init_c
    else:
        self.c_ = F.zeros([self.out_size_])

    if init_h.valid():
        self.h_ = init_h
    else:
        self.h_ = F.zeros([self.out_size_])
def restart(self):
    """Re-bind the weights as graph nodes and zero the recurrent state."""
    self.wxh = F.parameter(self.pwxh)
    self.whh = F.parameter(self.pwhh)
    self.bh = F.parameter(self.pbh)

    # `h` and `c` deliberately share one and the same zeros node.
    zero_state = F.zeros([self.out_size])
    self.h = zero_state
    self.c = zero_state
def init(self):
    """Bind the weights as graph nodes and zero the recurrent state."""
    self.wxh_ = F.parameter(self.pwxh_)
    self.whh_ = F.parameter(self.pwhh_)
    self.bh_ = F.parameter(self.pbh_)

    # `h_` and `c_` deliberately share one and the same zeros node.
    zeros = F.zeros([self.out_size_])
    self.h_ = zeros
    self.c_ = zeros