예제 #1
0
    def forward(self, inputs, weights):
        """Run one step of a GRU cell.

        Args:
          inputs: pair ``(x, gru_state)``: the current input and the previous
            state. Both are concatenated along the last axis, so they must
            agree on every other axis.
          weights: 4-tuple ``(w1, b1, w2, b2)``. ``w1``/``b1`` form the dense
            layer whose sigmoid output is halved into the update and reset
            gates; ``w2``/``b2`` form the dense layer for the candidate.

        Returns:
          ``(new_state, new_state)``: the same updated state array twice
          (presumably once as layer output and once as carried state).
        """
        x, prev_state = inputs
        gate_w, gate_b, cand_w, cand_b = weights

        # A single dense layer over [x, h] gives both gates; the sigmoid output
        # is cut in half along the last axis: first half update, second reset.
        gate_logits = np.dot(np.concatenate([x, prev_state], axis=-1),
                             gate_w) + gate_b
        update, reset = np.split(backend.sigmoid(gate_logits), 2, axis=-1)

        # The candidate sees the previous state only through the reset gate.
        cand_logits = np.dot(np.concatenate([x, reset * prev_state], axis=-1),
                             cand_w) + cand_b

        # Blend: `update` keeps the old state, (1 - update) admits the candidate.
        new_state = update * prev_state + (1 - update) * np.tanh(cand_logits)
        return new_state, new_state
예제 #2
0
    def forward(self, x, weights):
        """Merge attention heads back into the feature axis, then project.

        The first axis of ``x`` is regrouped in runs of ``self._n_heads``
        consecutive rows (one run per batch element). The heads of each run
        are laid side by side along the last axis, and the result is
        multiplied by ``weights``.

        Args:
          x: array read as ``(n_batch * n_heads, seqlen, d_head)``. Axes 1
            and 2 are taken as seqlen and d_head; assumes rank 3, so confirm
            against callers.
          weights: matrix applied with ``np.dot`` to the merged array, whose
            last axis has size ``n_heads * d_head``.

        Returns:
          ``np.dot`` of the merged ``(n_batch, seqlen, n_heads * d_head)``
          array with ``weights``.
        """
        seq_len, head_dim = x.shape[1], x.shape[2]
        heads = self._n_heads

        # (n_batch*n_heads, seqlen, d_head) -> (n_batch, n_heads, seqlen, d_head)
        grouped = np.reshape(x, (-1, heads, seq_len, head_dim))
        # -> (n_batch, seqlen, n_heads, d_head)
        grouped = np.transpose(grouped, (0, 2, 1, 3))
        # -> (n_batch, seqlen, n_heads * d_head)
        merged = np.reshape(grouped, (-1, seq_len, heads * head_dim))
        return np.dot(merged, weights)
예제 #3
0
    def forward(self, x, weights):
        """Project ``x`` and split the projection into attention heads.

        Args:
          x: array whose axis 0 is the batch and axis 1 the sequence length.
            ``np.dot(x, weights)`` must have a last axis of size
            ``self._n_heads * self._d_head``.
          weights: projection matrix applied with ``np.dot``.

        Returns:
          Array of shape ``(x.shape[0] * n_heads, seqlen, d_head)``. Row
          ``b * n_heads + h`` holds head ``h`` of batch element ``b``.
        """
        seq_len = x.shape[1]
        projected = np.dot(x, weights)

        # (n_batch, seqlen, n_heads*d_head) -> (n_batch, seqlen, n_heads, d_head)
        split = np.reshape(
            projected, (x.shape[0], seq_len, self._n_heads, self._d_head))
        # Move heads ahead of the sequence axis:
        # -> (n_batch, n_heads, seqlen, d_head)
        split = np.transpose(split, (0, 2, 1, 3))
        # Fold heads into the leading axis:
        # -> (n_batch*n_heads, seqlen, d_head)
        return np.reshape(split, (-1, seq_len, self._d_head))
예제 #4
0
    def forward(self, inputs, weights):
        """Run one step of an LSTM cell.

        Args:
          inputs: pair ``(x, lstm_state)``. ``lstm_state`` packs the cell
            state ``c`` and the hidden state ``h`` side by side on the last
            axis, with ``c`` first.
          weights: pair ``(w, b)`` for the one dense layer over ``[x, h]``.
            Its output is split into four equal parts on the last axis.

        Returns:
          ``(new_h, packed_state)``, where ``packed_state`` is
          ``[new_c, new_h]`` concatenated on the last axis (the same layout
          as the input state).
        """
        x, lstm_state = inputs

        # Unpack the recurrent state: first half is c, second half is h.
        cell, hidden = np.split(lstm_state, 2, axis=-1)

        kernel, bias = weights
        logits = np.dot(np.concatenate([x, hidden], axis=-1), kernel) + bias

        # Four equal slices: input gate, new input, forget gate, output gate.
        in_gate, new_input, forget_gate, out_gate = np.split(logits, 4, axis=-1)

        next_cell = (cell * backend.sigmoid(forget_gate)
                     + backend.sigmoid(in_gate) * np.tanh(new_input))
        next_hidden = np.tanh(next_cell) * backend.sigmoid(out_gate)
        return next_hidden, np.concatenate([next_cell, next_hidden], axis=-1)
예제 #5
0
파일: core.py 프로젝트: wangleiphy/trax
 def forward(self, x, weights):
   """Return the affine map ``np.dot(x, w) + b``, where ``(w, b) = weights``."""
   kernel, bias = weights
   projected = np.dot(x, kernel)
   return projected + bias
예제 #6
0
 def forward(self, x, params=(), state=(), **kwargs):
     """Apply an affine map and pass ``state`` through unchanged.

     Args:
       x: input array.
       params: pair ``(w, b)``. The default ``()`` cannot be unpacked, so
         callers must supply it.
       state: returned as-is.
       **kwargs: accepted and discarded.

     Returns:
       ``(np.dot(x, w) + b, state)``.
     """
     del kwargs  # extra keyword arguments are ignored
     kernel, bias = params
     output = np.dot(x, kernel) + bias
     return output, state