Beispiel #1
0
 def reset(self, init_c=Node(), init_h=Node()):
     """Re-initializes internal states.

     Re-binds the stored parameters ``_pwxh``, ``_pwhh`` and ``_pbh`` as
     graph nodes, then sets the cell state ``_c`` and hidden state ``_h``.

     Args:
         init_c: Initial cell state. Used as-is when ``init_c.valid()`` is
             true; otherwise a zero vector of the output size is created.
         init_h: Initial hidden state, handled the same way as ``init_c``.
     """
     # The second dimension of the recurrent weight gives the state size.
     hidden_dim = self._pwhh.shape()[1]
     self._wxh = F.parameter(self._pwxh)
     self._whh = F.parameter(self._pwhh)
     self._bh = F.parameter(self._pbh)
     if init_c.valid():
         self._c = init_c
     else:
         self._c = F.zeros([hidden_dim])
     if init_h.valid():
         self._h = init_h
     else:
         self._h = F.zeros([hidden_dim])
Beispiel #2
0
 def forward(self, inputs):
     """Runs the sigmoid recurrent network over ``inputs``.

     Args:
         inputs: Sequence of per-step inputs, each passed to ``F.pick`` as
             lookup indices; ``len(inputs[0])`` is used as the batch size
             (presumably each element is a batch of word IDs — confirm).

     Returns:
         A list of output nodes, one for every element of ``inputs`` except
         the last (``len(inputs) - 1`` entries).
     """
     batch = len(inputs[0])
     lookup = F.parameter(self.pwlookup)
     w_in = F.parameter(self.pwxs)
     w_out = F.parameter(self.pwsy)
     state = F.zeros(Shape([NUM_HIDDEN_UNITS], batch))
     outputs = []
     # The last element of ``inputs`` is never fed into the network
     # (presumably it is only a prediction target).
     for ids in inputs[:-1]:
         embedded = F.pick(lookup, ids, 1)
         # The embedding is added directly to the previous state.
         state = F.sigmoid(w_in @ (embedded + state))
         outputs.append(w_out @ state)
     return outputs
    def encode(self, src_batch, train):
        """Encodes source sentences and prepares internal states.

        Embeds ``src_batch`` and runs it through a forward and a backward
        LSTM. It stores the per-position sums of both directions, then binds
        the target-side parameters and restarts the target LSTM from the
        summed final encoder states.

        Args:
            src_batch: Sequence of per-position source inputs, each passed to
                ``F.pick`` as lookup indices (presumably a batch of word IDs —
                confirm against the caller).
            train: Forwarded to every ``F.dropout`` call; presumably enables
                dropout only during training.

        Side effects:
            Sets ``concat_fb``, ``t_concat_fb``, ``trg_lookup``, ``whj``,
            ``bj``, ``wjy``, ``by`` and ``feed`` on ``self`` and restarts
            ``src_fw_lstm``, ``src_bw_lstm`` and ``trg_lstm``.
        """

        def run_lstm(lstm, seq):
            # Restart ``lstm``, then feed ``seq`` one step at a time,
            # applying dropout to each output.
            lstm.restart()
            return [
                F.dropout(lstm.forward(v), self.dropout_rate, train)
                for v in seq
            ]

        # Embedding lookup (with dropout) for every source position.
        src_lookup = F.parameter(self.psrc_lookup)
        embeddings = [
            F.dropout(F.pick(src_lookup, ids, 1), self.dropout_rate, train)
            for ids in src_batch
        ]

        fw_states = run_lstm(self.src_fw_lstm, embeddings)
        # Run right-to-left, then flip back so index i is source position i.
        bw_states = run_lstm(self.src_bw_lstm, reversed(embeddings))[::-1]

        # Sum both directions per position and concatenate along dim 1.
        summed = [fw + bw for fw, bw in zip(fw_states, bw_states)]
        self.concat_fb = F.concat(summed, 1)
        self.t_concat_fb = F.transpose(self.concat_fb)

        # Decoder-side nodes; ``feed`` is a zero vector whose size is the
        # first dimension of the source lookup table.
        embed_size = self.psrc_lookup.shape()[0]
        self.trg_lookup = F.parameter(self.ptrg_lookup)
        self.whj = F.parameter(self.pwhj)
        self.bj = F.parameter(self.pbj)
        self.wjy = F.parameter(self.pwjy)
        self.by = F.parameter(self.pby)
        self.feed = F.zeros([embed_size])
        init_c = self.src_fw_lstm.get_c() + self.src_bw_lstm.get_c()
        init_h = self.src_fw_lstm.get_h() + self.src_bw_lstm.get_h()
        self.trg_lstm.restart(init_c, init_h)
Beispiel #4
0
    def forward(self, xs):
        """Applies the gated recurrent layer to a whole input sequence.

        All gate pre-activations come from one matrix product over the
        concatenated inputs; only the cell update runs step by step.

        Args:
            xs: Sequence of per-step input nodes. They are concatenated along
                dimension 1, and each one is also mixed into its step's
                output, so presumably its size equals ``self.out_size`` —
                confirm.

        Returns:
            A list of output nodes, one per element of ``xs``.
        """
        size = self.out_size
        projected = self.w @ F.concat(xs, 1)
        steps = len(xs)
        # Rows [0, size): candidate values; [size, 2*size): gate f used to
        # blend the old cell with the candidate; [2*size, 3*size): gate r
        # used to blend tanh(cell) with the raw input.
        cand = F.slice(projected, 0, 0, size)
        f_gate = F.sigmoid(
            F.slice(projected, 0, size, 2 * size)
            + F.broadcast(self.bf, 1, steps))
        # NOTE(review): this gate also adds ``self.bf``. If the model defines
        # a separate bias for it, this looks like a copy-paste slip — confirm.
        r_gate = F.sigmoid(
            F.slice(projected, 0, 2 * size, 3 * size)
            + F.broadcast(self.bf, 1, steps))
        cell = F.zeros([size])
        outputs = []
        for t, x_t in enumerate(xs):
            cand_t = F.slice(cand, 1, t, t + 1)
            f_t = F.slice(f_gate, 1, t, t + 1)
            r_t = F.slice(r_gate, 1, t, t + 1)
            cell = f_t * cell + (1 - f_t) * cand_t
            outputs.append(r_t * F.tanh(cell) + (1 - r_t) * x_t)
        return outputs
Beispiel #5
0
 def restart(self):
     """Re-binds the parameters as graph nodes and zeroes the state.

     ``self.h`` and ``self.c`` both end up referring to the same zero node
     of size ``self.out_size``.
     """
     self.wxh = F.parameter(self.pwxh)
     self.whh = F.parameter(self.pwhh)
     self.bh = F.parameter(self.pbh)
     zero_state = F.zeros([self.out_size])
     self.h = zero_state
     self.c = zero_state