Example #1
def forward(self, x, weights):
    if self._use_bias:
        if not isinstance(weights, (tuple, list)):
            raise ValueError(f'Weights should be a (w, b) tuple or list; '
                             f'instead got: {weights}')
        w, b = weights
        return jnp.dot(x, w) + b  # Affine map.
    else:
        w = weights
        return jnp.dot(x, w)  # Linear map.
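The snippet above appears to come from Trax's `Dense` layer; `jnp` is `jax.numpy` as re-exported by Trax's backend. A minimal standalone sketch of the same affine map, with toy shapes picked purely for illustration:

import jax.numpy as jnp

# Toy sizes chosen for illustration: batch of 2, 3 input features, 4 units.
x = jnp.ones((2, 3))
w = jnp.full((3, 4), 0.1)
b = jnp.zeros(4)

y = jnp.dot(x, w) + b  # Affine map.
print(y.shape)  # (2, 4)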
Example #2
    def forward(self, inputs, weights):
        x, gru_state = inputs

        # Dense layer on the concatenation of x and h.
        w1, b1, w2, b2 = weights
        y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1

        # Update and reset gates.
        u, r = jnp.split(math.sigmoid(y), 2, axis=-1)

        # Candidate.
        c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2

        new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
        return new_gru_state, new_gru_state
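In the snippet above, `math` is Trax's backend math module (`trax.math`, later renamed `trax.fastmath`). A self-contained sketch of the same GRU step in plain `jax.numpy`, with hypothetical sizes and random weights standing in for the layer's trained parameters:

import jax
import jax.numpy as jnp

# Toy sizes chosen for illustration.
d_in, d_hidden, batch = 3, 4, 2
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# w1/b1 produce both gates at once, hence output width 2 * d_hidden.
w1 = jax.random.normal(k1, (d_in + d_hidden, 2 * d_hidden))
b1 = jnp.zeros(2 * d_hidden)
w2 = jax.random.normal(k2, (d_in + d_hidden, d_hidden))
b2 = jnp.zeros(d_hidden)

x = jnp.ones((batch, d_in))
h = jnp.zeros((batch, d_hidden))

y = jnp.dot(jnp.concatenate([x, h], axis=-1), w1) + b1
u, r = jnp.split(jax.nn.sigmoid(y), 2, axis=-1)             # update, reset gates
c = jnp.dot(jnp.concatenate([x, r * h], axis=-1), w2) + b2  # candidate
new_h = u * h + (1 - u) * jnp.tanh(c)
print(new_h.shape)  # (2, 4)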
Example #3
def forward(self, x, weights):
    if len(weights) != 2:
        raise ValueError(
            f'Weights has length {len(weights)}; should instead '
            f'have two elements: w, b.')
    w, b = weights
    return jnp.dot(x, w) + b
Example #4
  def forward(self, x, weights):
    seqlen = x.shape[1]
    d_head = x.shape[2]

    x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
    x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

    return np.dot(x, weights)
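This appears to be the output half of a multi-head attention block: the heads are folded back into the feature axis before a final projection (`np` here is the backend's `jax.numpy`). A shape walk-through under assumed toy dimensions:

import jax.numpy as jnp

# Toy sizes chosen for illustration.
n_batch, n_heads, seqlen, d_head = 2, 4, 5, 8
x = jnp.ones((n_batch * n_heads, seqlen, d_head))      # (8, 5, 8)

x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))      # (2, 4, 5, 8)
x = jnp.transpose(x, (0, 2, 1, 3))                     # (2, 5, 4, 8)
x = jnp.reshape(x, (-1, seqlen, n_heads * d_head))     # (2, 5, 32)

weights = jnp.ones((n_heads * d_head, 16))             # assumed d_model = 16
print(jnp.dot(x, weights).shape)                       # (2, 5, 16)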
Example #5
  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input, except the final dimension
      is the layer's `n_units` value.
    """
    if self._use_bias:
      if not isinstance(self.weights, (tuple, list)):
        raise ValueError(f'Weights should be a (w, b) tuple or list; '
                         f'instead got: {self.weights}')
      w, b = self.weights
      return jnp.dot(x, w) + b  # Affine map.
    else:
      w = self.weights
      return jnp.dot(x, w)  # Linear map.
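This is the same `Dense.forward` as Example #1, except the weights are read from `self.weights` instead of being passed in. For context, the layer is used roughly as follows; the API names assume a recent Trax release and are not taken from the snippet itself:

import numpy as np
import trax
from trax import layers as tl

layer = tl.Dense(4)                    # n_units = 4
x = np.ones((2, 3), dtype=np.float32)
layer.init(trax.shapes.signature(x))   # builds (w, b) for this input shape
y = layer(x)                           # invokes forward under the hood
print(y.shape)                         # (2, 4)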
Example #6
  def forward(self, x, weights):
    seqlen = x.shape[1]
    res = np.dot(x, weights)

    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    res = np.transpose(res, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    res = np.reshape(res, (-1, seqlen, self._d_head))

    return res
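Example #6 is the inverse of Example #4: it projects the input into per-head features and splits the heads out into the batch axis. A quick sketch, with hypothetical sizes, checking that the split-then-merge reshapes round-trip:

import jax.numpy as jnp

# Toy sizes chosen for illustration.
n_batch, n_heads, seqlen, d_head = 2, 4, 5, 8

orig = jnp.arange(n_batch * seqlen * n_heads * d_head, dtype=jnp.float32)
orig = jnp.reshape(orig, (n_batch, seqlen, n_heads * d_head))

# Split heads out into the batch axis (Example #6, minus the projection).
res = jnp.reshape(orig, (n_batch, seqlen, n_heads, d_head))
res = jnp.transpose(res, (0, 2, 1, 3))
res = jnp.reshape(res, (-1, seqlen, d_head))            # (8, 5, 8)

# Merge them back (Example #4, minus the projection).
res = jnp.reshape(res, (-1, n_heads, seqlen, d_head))
res = jnp.transpose(res, (0, 2, 1, 3))
res = jnp.reshape(res, (-1, seqlen, n_heads * d_head))  # (2, 5, 32)

print(jnp.array_equal(res, orig))  # True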
Example #7
    def forward(self, inputs, weights):
        x, lstm_state = inputs

        # LSTM state consists of c and h.
        c, h = jnp.split(lstm_state, 2, axis=-1)

        # Dense layer on the concatenation of x and h.
        w, b = weights
        y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = jnp.split(y, 4, axis=-1)

        new_c = c * math.sigmoid(f) + math.sigmoid(i) * jnp.tanh(j)
        new_h = jnp.tanh(new_c) * math.sigmoid(o)
        return new_h, jnp.concatenate([new_c, new_h], axis=-1)
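As with the GRU in Example #2, `math.sigmoid` refers to Trax's backend math module. A standalone sketch of one LSTM step in plain `jax.numpy`, with assumed toy sizes:

import jax
import jax.numpy as jnp

# Toy sizes chosen for illustration.
d_in, d_hidden, batch = 3, 4, 2

# One dense layer computes all four gates, hence output width 4 * d_hidden.
w = jax.random.normal(jax.random.PRNGKey(0), (d_in + d_hidden, 4 * d_hidden))
b = jnp.zeros(4 * d_hidden)

x = jnp.ones((batch, d_in))
c = jnp.zeros((batch, d_hidden))
h = jnp.zeros((batch, d_hidden))

y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
i, j, f, o = jnp.split(y, 4, axis=-1)

new_c = c * jax.nn.sigmoid(f) + jax.nn.sigmoid(i) * jnp.tanh(j)
new_h = jnp.tanh(new_c) * jax.nn.sigmoid(o)
print(new_h.shape)  # (2, 4)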
Example #8
def forward(self, x, weights):
    w, b = weights
    return np.dot(x, w) + b