def forward(self, x, weights):
  if self._use_bias:
    if not isinstance(weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {weights}')
    w, b = weights
    return jnp.dot(x, w) + b  # Affine map.
  else:
    w = weights
    return jnp.dot(x, w)  # Linear map.
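# Below: a minimal standalone sketch of the affine map above, assuming plain
# jax.numpy in place of the layer's backend `jnp`; the sizes (batch=4,
# d_in=8, n_units=16) are hypothetical.
import jax.numpy as jnp
from jax import random

k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (4, 8))    # (batch, d_in)
w = random.normal(k2, (8, 16))   # (d_in, n_units)
b = jnp.zeros((16,))             # (n_units,)
y = jnp.dot(x, w) + b            # Affine map -> (batch, n_units)
assert y.shape == (4, 16)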
def forward(self, inputs, weights):
  x, gru_state = inputs
  # Dense layer on the concatenation of x and h.
  w1, b1, w2, b2 = weights
  y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1
  # Update and reset gates.
  u, r = jnp.split(math.sigmoid(y), 2, axis=-1)
  # Candidate.
  c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2
  new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
  return new_gru_state, new_gru_state
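# Below: a self-contained sketch of the GRU step above, with jax.nn.sigmoid
# standing in for the backend's `math.sigmoid`; the free function and the
# dimensions (d_in=8, d_h=8) are hypothetical.
import jax.numpy as jnp
from jax import nn, random

def gru_step(x, h, w1, b1, w2, b2):
  y = jnp.dot(jnp.concatenate([x, h], axis=-1), w1) + b1
  u, r = jnp.split(nn.sigmoid(y), 2, axis=-1)       # Update and reset gates.
  c = jnp.dot(jnp.concatenate([x, r * h], axis=-1), w2) + b2
  return u * h + (1 - u) * jnp.tanh(c)              # New hidden state.

d_in, d_h = 8, 8
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
x = random.normal(k1, (2, d_in))
h = jnp.zeros((2, d_h))
w1 = random.normal(k2, (d_in + d_h, 2 * d_h))  # Gates: two blocks of d_h.
b1 = jnp.zeros((2 * d_h,))
w2 = random.normal(k3, (d_in + d_h, d_h))      # Candidate state.
b2 = jnp.zeros((d_h,))
assert gru_step(x, h, w1, b1, w2, b2).shape == (2, d_h)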
def forward(self, x, weights):
  if len(weights) != 2:
    raise ValueError(f'Weights has length {len(weights)}; should instead '
                     f'have two elements: w, b.')
  w, b = weights
  return jnp.dot(x, w) + b
def forward(self, x, weights):
  seqlen = x.shape[1]
  d_head = x.shape[2]
  # n_batch*n_heads, seqlen, d_head -> n_batch, n_heads, seqlen, d_head
  x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
  x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
  x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))
  return np.dot(x, weights)
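# Below: a shape trace for the head-merging step above, with plain numpy
# standing in for the backend `np`; the sizes (n_batch=2, n_heads=4,
# seqlen=5, d_head=3) are hypothetical.
import numpy as np

n_batch, n_heads, seqlen, d_head = 2, 4, 5, 3
x = np.zeros((n_batch * n_heads, seqlen, d_head))  # (8, 5, 3)
x = np.reshape(x, (-1, n_heads, seqlen, d_head))   # (2, 4, 5, 3)
x = np.transpose(x, (0, 2, 1, 3))                  # (2, 5, 4, 3)
x = np.reshape(x, (-1, seqlen, n_heads * d_head))  # (2, 5, 12)
assert x.shape == (n_batch, seqlen, n_heads * d_head)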
def forward(self, x): """Executes this layer as part of a forward pass through the model. Args: x: Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except the final dimension is the layer's `n_units` value. """ if self._use_bias: if not isinstance(self.weights, (tuple, list)): raise ValueError(f'Weights should be a (w, b) tuple or list; ' f'instead got: {self.weights}') w, b = self.weights return jnp.dot(x, w) + b # Affine map. else: w = self.weights return jnp.dot(x, w) # Linear map.
def forward(self, x, weights):
  seqlen = x.shape[1]
  res = np.dot(x, weights)
  # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
  res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
  # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
  res = np.transpose(res, (0, 2, 1, 3))
  # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
  res = np.reshape(res, (-1, seqlen, self._d_head))
  return res
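# Below: a round-trip sketch showing that the split here and the merge in the
# earlier forward are inverse reshapes. Plain numpy stands in for the backend
# `np`, the projection is skipped, and the sizes are hypothetical.
import numpy as np

n_batch, seqlen, n_heads, d_head = 2, 5, 4, 3
res = np.arange(n_batch * seqlen * n_heads * d_head, dtype=np.float32)
res = res.reshape((n_batch, seqlen, n_heads * d_head))
# Split: (2, 5, 12) -> (8, 5, 3).
split = np.transpose(res.reshape((n_batch, seqlen, n_heads, d_head)),
                     (0, 2, 1, 3)).reshape((-1, seqlen, d_head))
# Merge: (8, 5, 3) -> (2, 5, 12).
merged = np.transpose(split.reshape((-1, n_heads, seqlen, d_head)),
                      (0, 2, 1, 3)).reshape((-1, seqlen, n_heads * d_head))
assert np.array_equal(merged, res)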
def forward(self, inputs, weights):
  x, lstm_state = inputs
  # LSTM state consists of c and h.
  c, h = jnp.split(lstm_state, 2, axis=-1)
  # Dense layer on the concatenation of x and h.
  w, b = weights
  y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
  # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  i, j, f, o = jnp.split(y, 4, axis=-1)
  new_c = c * math.sigmoid(f) + math.sigmoid(i) * jnp.tanh(j)
  new_h = jnp.tanh(new_c) * math.sigmoid(o)
  return new_h, jnp.concatenate([new_c, new_h], axis=-1)
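# Below: a self-contained sketch of the LSTM step above, showing that a
# single dense weight serves all four gates. jax.nn.sigmoid stands in for
# the backend's `math.sigmoid`; the function and sizes are hypothetical.
import jax.numpy as jnp
from jax import nn, random

def lstm_step(x, lstm_state, w, b):
  c, h = jnp.split(lstm_state, 2, axis=-1)
  y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
  i, j, f, o = jnp.split(y, 4, axis=-1)  # Four gates from one matmul.
  new_c = c * nn.sigmoid(f) + nn.sigmoid(i) * jnp.tanh(j)
  new_h = jnp.tanh(new_c) * nn.sigmoid(o)
  return new_h, jnp.concatenate([new_c, new_h], axis=-1)

d_in, d_h = 8, 8
k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (2, d_in))
state = jnp.zeros((2, 2 * d_h))               # c and h stacked on last axis.
w = random.normal(k2, (d_in + d_h, 4 * d_h))  # One matrix, four gate blocks.
b = jnp.zeros((4 * d_h,))
new_h, new_state = lstm_step(x, state, w, b)
assert new_h.shape == (2, d_h) and new_state.shape == (2, 2 * d_h)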
def forward(self, x, weights):
  w, b = weights
  return np.dot(x, w) + b  # Affine map.