コード例 #1
0
ファイル: efattentional.py プロジェクト: philip30/chainn
 def __init__(self, I, E, H, depth, dropout_ratio):
     """Create the bidirectional encoder's links (Encoder constructor).

     Args:
         I: source vocabulary size (input size of the ``IE`` embedding).
         E: embedding size; input size of both StackLSTMs.
         H: hidden size of both StackLSTMs; ``AE`` maps 2*H back to H.
         depth: depth passed to both StackLSTMs.
         dropout_ratio: dropout ratio passed to both StackLSTMs.
     """
     self.IE = L.EmbedID(I, E)
     # Two independent StackLSTMs with identical configuration: EF and EB.
     self.EF = StackLSTM(E, H, depth, dropout_ratio)
     self.EB = StackLSTM(E, H, depth, dropout_ratio)
     # Merges a concatenated pair of H-sized outputs back to H units.
     self.AE = L.Linear(2*H, H)
     self.H  = H
     # NOTE(review): links are created before ChainList.__init__ and then
     # registered in this order; creation order presumably also fixes how
     # random initialisation is consumed, so keep it stable.
     super(Encoder, self).__init__(self.IE, self.EF, self.EB, self.AE)
コード例 #2
0
ファイル: recurrent_lstm.py プロジェクト: philip30/chainn
class RecurrentLSTM(ChainnBasicModel):
    """Sequence model: embedding -> StackLSTM -> activation -> linear -> softmax."""

    name = "lstm"

    def _construct_model(self, input, output, hidden, depth, embed):
        """Create the model's links and return them in registration order.

        Args:
            input: input vocabulary size for the embedding.
            output: number of output units of the final linear layer.
            hidden: hidden size of the StackLSTM.
            depth: StackLSTM depth; must be at least 1.
            embed: embedding size.
        """
        assert depth >= 1
        # Each link is created and attached right away, in the same order as
        # before, so parameter initialisation is unaffected.
        self.embed = L.EmbedID(input, embed)
        self.inner = StackLSTM(embed, hidden, depth, self._dropout)
        self.output = L.Linear(hidden, output)
        links = [self.embed, self.inner, self.output]
        return links

    def reset_state(self, *args, **kwargs):
        """Reset the recurrent state of the inner StackLSTM; arguments are ignored."""
        self.inner.reset_state()

    def __call__(self, word, ref=None, is_train=False):
        """Return the softmax distribution for the next output given ``word``.

        ``ref`` is not used by this method. ``is_train`` is forwarded to the
        StackLSTM.
        """
        embedded = self.embed(word)
        hidden_state = self.inner(embedded, is_train=is_train)
        scores = self.output(self._activation(hidden_state))
        return F.softmax(scores)
コード例 #3
0
ファイル: efattentional.py プロジェクト: philip30/chainn
 def __init__(self, O, E, H, depth, dropout_ratio):
     """Create the attentional decoder's links (Decoder constructor).

     Args:
         O: output vocabulary size (WS output size, OE input size).
         E: embedding size; input size of the decoder StackLSTM.
         H: hidden size of the decoder StackLSTM.
         depth: depth passed to the StackLSTM.
         dropout_ratio: dropout ratio passed to the StackLSTM.
     """
     self.DF = StackLSTM(E, H, depth, dropout_ratio)
     # Hidden (H) -> output vocabulary scores (O).
     self.WS = L.Linear(H, O)
     # Concatenation of hidden state and context (2*H) -> H.
     self.WC = L.Linear(2*H, H)
     # Output-token embedding fed back into the StackLSTM.
     self.OE = L.EmbedID(O, E)
     # Projects an H-sized state into the embedding space (E).
     self.HE = L.Linear(H, E)
     # NOTE(review): creation order presumably fixes how random initialisation
     # is consumed; keep it stable.
     super(Decoder, self).__init__(self.DF, self.WS, self.WC, self.OE, self.HE)
コード例 #4
0
ファイル: efattentional.py プロジェクト: philip30/chainn
class Decoder(ChainList):
    """Attentional StackLSTM decoder.

    Registered links, in order: ``DF`` (decoder StackLSTM), ``WS`` (H -> O
    output scores), ``WC`` (2*H -> H merge of hidden state and context),
    ``OE`` (output-token embedding) and ``HE`` (H -> E projection).
    """

    def __init__(self, O, E, H, depth, dropout_ratio):
        """Build the links for output vocabulary ``O``, embedding ``E``, hidden ``H``."""
        # Construction order is kept identical to preserve initialisation.
        self.DF = StackLSTM(E, H, depth, dropout_ratio)
        self.WS = L.Linear(H, O)
        self.WC = L.Linear(2*H, H)
        self.OE = L.EmbedID(O, E)
        self.HE = L.Linear(H, E)
        registered = (self.DF, self.WS, self.WC, self.OE, self.HE)
        super(Decoder, self).__init__(*registered)

    def __call__(self, s, a, h):
        """Compute output scores from hidden state ``h`` and a context built from ``s`` and ``a``.

        The context is ``batch_matmul(a, s, transa=True)``, reshaped to
        ``h``'s shape. It is concatenated with ``h`` and passed through ``WC``
        and tanh, and the result is scored by ``WS``.
        NOTE(review): ``s`` is presumably the encoder's (B, N, H) annotations
        and ``a`` the attention weights over the N positions; confirm with
        callers.
        """
        weighted = F.batch_matmul(a, s, transa=True)
        context = F.reshape(weighted, h.data.shape)
        joined = F.concat((h, context), axis=1)
        attentional_state = F.tanh(self.WC(joined))
        return self.WS(attentional_state)

    def reset(self, s, is_train=False):
        """Clear the decoder state and feed ``HE(s)`` as the first input.

        NOTE(review): ``s`` is presumably the encoder's final merged state.
        """
        self.DF.reset_state()
        first_input = self.HE(s)
        return self.DF(first_input, is_train=is_train)

    def update(self, wt, is_train=False):
        """Advance the decoder by one step on the embedded output token ``wt``."""
        token_embedding = self.OE(wt)
        return self.DF(token_embedding, is_train=is_train)
コード例 #5
0
ファイル: efattentional.py プロジェクト: philip30/chainn
class Encoder(ChainList):
    """Bidirectional encoder used by the attentional model.

    Each source token is embedded with ``IE`` and fed to two StackLSTMs:
    ``EF`` reads the sentence left-to-right and ``EB`` right-to-left. At every
    source position the two directional outputs are concatenated and merged
    back to ``H`` units by the linear layer ``AE``.
    """

    def __init__(self, I, E, H, depth, dropout_ratio):
        """Create the encoder links.

        Args:
            I: source vocabulary size (input size of the ``IE`` embedding).
            E: embedding size; input size of both StackLSTMs.
            H: hidden size of both StackLSTMs and of the merged annotation.
            depth: depth passed to both StackLSTMs.
            dropout_ratio: dropout ratio passed to both StackLSTMs.
        """
        # NOTE(review): creation order presumably fixes how random
        # initialisation is consumed; keep it stable.
        self.IE = L.EmbedID(I, E)
        self.EF = StackLSTM(E, H, depth, dropout_ratio)
        self.EB = StackLSTM(E, H, depth, dropout_ratio)
        self.AE = L.Linear(2*H, H)
        self.H  = H
        super(Encoder, self).__init__(self.IE, self.EF, self.EB, self.AE)

    def __call__(self, src, is_train=False, xp=np):
        """Encode a batch of source sentences.

        Args:
            src: batch of token-id sequences, indexed as ``src[i][j]``. Every
                sentence is read up to ``len(src[0])`` positions, so all are
                presumably padded to the same length by the caller.
            is_train: forwarded to both StackLSTMs.
            xp: unused; token arrays are created with ``self.xp``.

        Returns:
            ``(S, s_last)``. ``S`` has shape ``(B, N, H)`` and holds the merged
            annotation of every source position. ``s_last`` is the ``(B, H)``
            annotation of the last position ``N - 1``.
            NOTE(review): the backward half of ``s_last`` is EB's output after
            its first step (it only read the last token), not EB's final
            state. Confirm this is the intended decoder initialisation.

        Raises:
            ValueError: if ``src`` is empty or its first sentence is empty.
        """
        B = len(src)       # batch size
        if B == 0 or len(src[0]) == 0:
            raise ValueError("Encoder needs a non-empty batch of non-empty sentences")
        N = len(src[0])    # source length
        H = self.H

        def src_col(x):
            # Token ids at position x of every sentence, as an int32 Variable.
            return Variable(self.xp.array([src[i][x] for i in range(B)], dtype=np.int32))

        def embed(lstm, x):
            return lstm(self.IE(x), is_train=is_train)

        # Start both directions from a fresh recurrent state.
        self.EF.reset_state()
        self.EB.reset_state()

        # s[j] = (EF output after reading token j,
        #         EB output after reading token N-1-j)
        s = []
        for j in range(N):
            s.append((
                embed(self.EF, src_col(j)),
                embed(self.EB, src_col(-j-1))
            ))

        # The EB output that consumed token j is stored at s[N-1-j] == s[-j-1].
        # Collect every (B, H, 1) column and concatenate once. Concatenating
        # inside the loop recopied the growing tensor each step (O(N^2)).
        columns = []
        for j in range(N):
            s_j = self.AE(F.concat((s[j][0], s[-j-1][1]), axis=1))
            columns.append(F.reshape(s_j, (B, H, 1)))
        S = F.swapaxes(F.concat(columns, axis=2), 1, 2)
        return S, s_j
コード例 #6
0
ファイル: recurrent_lstm.py プロジェクト: philip30/chainn
 def _construct_model(self, input, output, hidden, depth, embed):
     """Create the RecurrentLSTM links and return them in registration order.

     Args:
         input: input vocabulary size for the embedding.
         output: number of output units of the final linear layer.
         hidden: hidden size of the StackLSTM.
         depth: StackLSTM depth; asserted to be at least 1.
         embed: embedding size.

     Returns:
         ``[embedding, StackLSTM, output Linear]``.
     """
     # NOTE(review): assert is stripped under ``python -O``; it is not a
     # reliable input check.
     assert(depth >= 1)
     self.embed  = L.EmbedID(input, embed)
     # self._dropout is presumably set by ChainnBasicModel before this call.
     self.inner  = StackLSTM(embed, hidden, depth, self._dropout)
     self.output = L.Linear(hidden, output)
     return [self.embed, self.inner, self.output]