Ejemplo n.º 1
0
 def decode_step(self, trg_words, train):
     """Run the target-side decoder for a single time step.

     Args:
         trg_words: Target word IDs for this step (presumably one per batch
             element — confirm against the caller).
         train: Whether dropout is active.

     Returns:
         The affine projection ``why @ h + by`` of the decoder LSTM's new
         hidden state ``h``.
     """
     embedded = F.dropout(
         F.pick(self.trg_lookup, trg_words, 1), self.dropout_rate, train)
     hidden = F.dropout(
         self.trg_lstm.forward(embedded), self.dropout_rate, train)
     return self.why @ hidden + self.by
Ejemplo n.º 2
0
 def decode_step(self, trg_words, train):
     """Run one attentional decoding step with input feeding.

     The previous attentional vector ``self.feed`` is concatenated with the
     current target embedding before it enters the decoder LSTM. The newly
     computed attentional vector is stored back into ``self.feed`` for the
     next step.

     Args:
         trg_words: Target word IDs for the current step.
         train: Whether dropout is applied.

     Returns:
         ``wjy @ feed + by`` computed from the updated ``feed``.
     """
     embedded = F.dropout(
         F.pick(self.trg_lookup, trg_words, 1), self.dropout_rate, train)
     lstm_input = F.concat([embedded, self.feed], 0)
     hidden = F.dropout(
         self.trg_lstm.forward(lstm_input), self.dropout_rate, train)

     # Dot-product scores against the encoder states, normalized along axis 0
     # (presumably one weight per source position).
     scores = self.t_concat_fb @ hidden
     weights = F.softmax(scores, 0)
     context = self.concat_fb @ weights

     joined = F.concat([hidden, context], 0)
     self.feed = F.tanh(self.whj @ joined + self.bj)
     return self.wjy @ self.feed + self.by
Ejemplo n.º 3
0
 def forward(self, inputs):
     """Unroll a simple sigmoid RNN over every step of ``inputs`` but the last.

     Args:
         inputs: Sequence of per-step word-ID batches. ``len(inputs[0])`` is
             used as the batch size of the initial zero state.

     Returns:
         List of ``wsy @ s`` output nodes, one per consumed step.
     """
     batch_size = len(inputs[0])
     lookup = F.parameter(self.pwlookup)
     w_in = F.parameter(self.pwxs)
     w_out = F.parameter(self.pwsy)
     state = F.zeros(Shape([NUM_HIDDEN_UNITS], batch_size))
     outputs = []
     for step_ids in inputs[:-1]:
         # The embedding is added to the previous state, not concatenated.
         state = F.sigmoid(w_in @ (F.pick(lookup, step_ids, 1) + state))
         outputs.append(w_out @ state)
     return outputs
Ejemplo n.º 4
0
    def encode(self, src_batch, train):
        """Encodes source sentences and prepares internal states.

        The source LSTM consumes the time steps of ``src_batch`` from last to
        first. Its final cell and hidden states then initialize the target
        LSTM.

        Args:
            src_batch: Sequence of per-time-step word-ID batches.
            train: Whether dropout is applied to the source embeddings.
        """
        # Reversed encoding: feed the source time steps last-to-first.
        # NOTE(review): this assumes callers pass src_batch in natural order;
        # confirm no caller already reverses it.
        src_lookup = F.parameter(self.psrc_lookup)
        self.src_lstm.restart()
        for step_ids in reversed(src_batch):
            x = F.pick(src_lookup, step_ids, 1)
            x = F.dropout(x, self.dropout_rate, train)
            self.src_lstm.forward(x)

        # Initializes decoder states.
        self.trg_lookup = F.parameter(self.ptrg_lookup)
        self.why = F.parameter(self.pwhy)
        self.by = F.parameter(self.pby)
        self.trg_lstm.restart(self.src_lstm.get_c(), self.src_lstm.get_h())
Ejemplo n.º 5
0
    def decode_step(self, trg_words, train):
        """Run one decoding step with additive (MLP-style) attention.

        Attention is computed from the decoder LSTM's hidden state *before*
        this step's update. The resulting context vector is concatenated with
        the target embedding to form the LSTM input.

        Args:
            trg_words: Target word IDs for the current step.
            train: Whether dropout is applied.

        Returns:
            ``wjy_ @ tanh(whj_ @ h + bj_) + by_`` for the new hidden state ``h``.
        """
        num_positions = self.concat_fb.shape()[1]

        # Project the previous decoder state and tile it once per column of
        # concat_fb so it can be added to every encoder row.
        query = self.whw_ @ self.trg_lstm_.get_h()
        query = F.broadcast(
            F.reshape(query, Shape([1, query.shape()[0]])), 0, num_positions)
        energy = F.tanh(self.t_concat_fb @ self.wfbw_ + query)
        weights = F.softmax(energy @ self.wwe_, 0)
        context = self.concat_fb @ weights

        embedded = F.dropout(
            F.pick(self.trg_lookup_, trg_words, 1), self.dropout_rate_, train)

        hidden = F.dropout(
            self.trg_lstm_.forward(F.concat([embedded, context], 0)),
            self.dropout_rate_, train)
        return self.wjy_ @ F.tanh(self.whj_ @ hidden + self.bj_) + self.by_
Ejemplo n.º 6
0
    def forward(self, inputs, train):
        """Run the two-layer RNN language model one step at a time.

        Args:
            inputs: Sequence of per-step word-ID batches. Every step except
                the last is consumed as input.
            train: Whether dropout is applied after the embedding and after
                each RNN layer.

        Returns:
            List of ``self.hy`` outputs, one per consumed step.
        """
        lookup = F.parameter(self.plookup)
        self.rnn1.restart()
        self.rnn2.restart()
        self.hy.reset()

        outputs = []
        for step_ids in inputs[:-1]:
            x = F.dropout(F.pick(lookup, step_ids, 1), DROPOUT_RATE, train)
            h1 = F.dropout(self.rnn1.forward(x), DROPOUT_RATE, train)
            h2 = F.dropout(self.rnn2.forward(h1), DROPOUT_RATE, train)
            outputs.append(self.hy.forward(h2))

        return outputs
Ejemplo n.º 7
0
    def encode(self, src_batch, train):
        """Bidirectionally encode source sentences and set up decoder state.

        Side effects:
            * ``self.concat_fb`` holds, as columns, the per-position sums of
              the forward and backward LSTM outputs. ``self.t_concat_fb`` is
              its transpose.
            * Decoder parameters are bound to graph nodes and ``self.feed`` is
              zeroed.
            * ``self.trg_lstm`` restarts from the summed final cell/hidden
              states of both source LSTMs.

        Args:
            src_batch: Sequence of per-time-step word-ID batches.
            train: Whether dropout is applied.
        """
        # Embedding lookup (with dropout) for every time step.
        src_lookup = F.parameter(self.psrc_lookup)
        embeds = [
            F.dropout(F.pick(src_lookup, step_ids, 1), self.dropout_rate, train)
            for step_ids in src_batch
        ]

        # Left-to-right pass.
        self.src_fw_lstm.restart()
        fw_states = [
            F.dropout(self.src_fw_lstm.forward(e), self.dropout_rate, train)
            for e in embeds
        ]

        # Right-to-left pass, then realign so that index i matches position i.
        self.src_bw_lstm.restart()
        bw_states = [
            F.dropout(self.src_bw_lstm.forward(e), self.dropout_rate, train)
            for e in reversed(embeds)
        ]
        bw_states.reverse()

        # Sum both directions per position and stack the results along dim 1.
        self.concat_fb = F.concat(
            [fw + bw for fw, bw in zip(fw_states, bw_states)], 1)
        self.t_concat_fb = F.transpose(self.concat_fb)

        # Decoder parameters and initial state.
        embed_size = self.psrc_lookup.shape()[0]
        self.trg_lookup = F.parameter(self.ptrg_lookup)
        self.whj = F.parameter(self.pwhj)
        self.bj = F.parameter(self.pbj)
        self.wjy = F.parameter(self.pwjy)
        self.by = F.parameter(self.pby)
        self.feed = F.zeros([embed_size])
        initial_c = self.src_fw_lstm.get_c() + self.src_bw_lstm.get_c()
        initial_h = self.src_fw_lstm.get_h() + self.src_bw_lstm.get_h()
        self.trg_lstm.restart(initial_c, initial_h)
Ejemplo n.º 8
0
    def encode(self, src_batch, train):
        """Bidirectionally encode source sentences and reset decoder state.

        Side effects:
            * ``self.concat_fb`` holds, as columns, the per-position
              concatenation (along dim 0) of the forward and backward LSTM
              outputs. ``self.t_concat_fb`` is its transpose.
            * Attention and output parameters are bound to graph nodes.
            * ``self.trg_lstm_`` is reset; it does not receive the encoder's
              final state.

        Args:
            src_batch: Sequence of per-time-step word-ID batches.
            train: Whether dropout is applied.
        """
        # Embedding lookup (with dropout) for every time step.
        src_lookup = F.parameter(self.psrc_lookup_)
        embeds = [
            F.dropout(F.pick(src_lookup, step_ids, 1), self.dropout_rate_, train)
            for step_ids in src_batch
        ]

        # Left-to-right pass.
        self.src_fw_lstm_.reset()
        fw_states = [
            F.dropout(self.src_fw_lstm_.forward(e), self.dropout_rate_, train)
            for e in embeds
        ]

        # Right-to-left pass, then realign so that index i matches position i.
        self.src_bw_lstm_.reset()
        bw_states = [
            F.dropout(self.src_bw_lstm_.forward(e), self.dropout_rate_, train)
            for e in reversed(embeds)
        ]
        bw_states.reverse()

        # Stack [forward; backward] per position, positions along dim 1.
        self.concat_fb = F.concat(
            [F.concat([fw, bw], 0) for fw, bw in zip(fw_states, bw_states)], 1)
        self.t_concat_fb = F.transpose(self.concat_fb)

        # Attention, embedding and output parameters for decoding.
        self.wfbw_ = F.parameter(self.pwfbw_)
        self.whw_ = F.parameter(self.pwhw_)
        self.wwe_ = F.parameter(self.pwwe_)
        self.trg_lookup_ = F.parameter(self.ptrg_lookup_)
        self.whj_ = F.parameter(self.pwhj_)
        self.bj_ = F.parameter(self.pbj_)
        self.wjy_ = F.parameter(self.pwjy_)
        self.by_ = F.parameter(self.pby_)
        self.trg_lstm_.reset()
Ejemplo n.º 9
0
    def forward(self, inputs, train):
        """Run the two-layer RNN language model over whole sequences at once.

        Args:
            inputs: Sequence of per-step word-ID batches. Every step except
                the last is consumed as input.
            train: Whether dropout is applied after the embedding and after
                each RNN layer.

        Returns:
            List of ``self.hy`` outputs, one per step produced by ``rnn2``.
        """
        lookup = F.parameter(self.plookup)
        self.rnn1.restart()
        self.rnn2.restart()
        self.hy.reset()

        xs = [
            F.dropout(F.pick(lookup, step_ids, 1), DROPOUT_RATE, train)
            for step_ids in inputs[:-1]
        ]
        # Build a new list instead of overwriting entries of the list returned
        # by rnn1.forward, which the RNN object might still reference.
        hs1 = [F.dropout(h, DROPOUT_RATE, train) for h in self.rnn1.forward(xs)]
        hs2 = self.rnn2.forward(hs1)
        return [self.hy.forward(F.dropout(h, DROPOUT_RATE, train)) for h in hs2]