Example #1
 def forward(self, x, weights):
     """Applies a dense (affine or linear) map to ``x``.

     Args:
       x: Input array, multiplied with ``w`` via ``jnp.dot``.
       weights: When ``self._use_bias`` is true, a ``(w, b)`` tuple or list;
           otherwise the weight array ``w`` itself.

     Returns:
       ``jnp.dot(x, w) + b`` when bias is enabled, else ``jnp.dot(x, w)``.

     Raises:
       ValueError: If bias is enabled and ``weights`` is not a tuple or list.
     """
     if self._use_bias:
         if not isinstance(weights, (tuple, list)):
             raise ValueError(f'Weights should be a (w, b) tuple or list; '
                              f'instead got: {weights}')
         w, b = weights
         return jnp.dot(x, w) + b  # Affine map.
     else:
         # Without bias, `weights` is the kernel itself; this branch must be
         # an explicit `else`, otherwise it is unreachable and None is returned.
         w = weights
         return jnp.dot(x, w)  # Linear map.
Example #2
    def forward(self, inputs, weights):
        """Runs a single GRU step.

        Args:
          inputs: A pair ``(x, gru_state)``.
          weights: A 4-tuple ``(w1, b1, w2, b2)``. ``w1``/``b1`` feed the
              sigmoid that yields the update and reset gates; ``w2``/``b2``
              feed the tanh candidate.

        Returns:
          The pair ``(new_gru_state, new_gru_state)``.
        """
        x, state = inputs
        gate_kernel, gate_bias, cand_kernel, cand_bias = weights

        # One dense layer over [x, state] produces both gates at once.
        gate_input = jnp.concatenate([x, state], axis=-1)
        gates = math.sigmoid(jnp.dot(gate_input, gate_kernel) + gate_bias)
        update, reset = jnp.split(gates, 2, axis=-1)

        # The candidate only sees the previous state scaled by the reset gate.
        cand_input = jnp.concatenate([x, reset * state], axis=-1)
        candidate = jnp.tanh(jnp.dot(cand_input, cand_kernel) + cand_bias)

        new_state = update * state + (1 - update) * candidate
        return new_state, new_state
Example #3
 def forward(self, x, weights):
     """Returns the affine map ``jnp.dot(x, w) + b``.

     Args:
       x: Input array.
       weights: A two-element sequence ``(w, b)``.

     Raises:
       ValueError: If ``weights`` does not have exactly two elements.
     """
     n_weights = len(weights)
     if n_weights == 2:
         kernel, bias = weights
         return jnp.dot(x, kernel) + bias
     raise ValueError(
         f'Weights has length {n_weights}; should instead '
         f'have two elements: w, b.')
Example #4
  def forward(self, x, weights):
    """Merges attention heads back into the feature axis, then projects.

    Args:
      x: Array whose leading axis folds ``n_batch * n_heads`` together, with
          sequence length on axis 1 and per-head features on axis 2.
      weights: Projection matrix applied with ``np.dot`` to the merged
          ``n_heads * d_head`` feature axis.

    Returns:
      ``np.dot(merged, weights)`` where ``merged`` has shape
      ``(n_batch, seqlen, n_heads * d_head)``.
    """
    seq_len, head_dim = x.shape[1], x.shape[2]
    n_heads = self._n_heads

    # -> n_batch, n_heads, seqlen, d_head
    split_heads = np.reshape(x, (-1, n_heads, seq_len, head_dim))
    # -> n_batch, seqlen, n_heads, d_head
    merged = np.transpose(split_heads, (0, 2, 1, 3))
    # -> n_batch, seqlen, n_heads * d_head
    merged = np.reshape(merged, (-1, seq_len, n_heads * head_dim))

    return np.dot(merged, weights)
Example #5
  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input, except the final dimension
      is the layer's `n_units` value.

    Raises:
      ValueError: If bias is enabled and ``self.weights`` is not a
          ``(w, b)`` tuple or list.
    """
    if self._use_bias:
      if not isinstance(self.weights, (tuple, list)):
        raise ValueError(f'Weights should be a (w, b) tuple or list; '
                         f'instead got: {self.weights}')
      w, b = self.weights
      return jnp.dot(x, w) + b  # Affine map.
    else:
      # Without bias, `self.weights` is the kernel itself; this must be an
      # explicit `else`, otherwise the branch is unreachable and None results.
      w = self.weights
      return jnp.dot(x, w)  # Linear map.
Example #6
  def forward(self, x, weights):
    """Projects ``x`` with ``weights`` and splits the result into heads.

    Args:
      x: Array with batch on axis 0 and sequence length on axis 1.
      weights: Projection matrix applied with ``np.dot``; its output width is
          expected to equal ``self._n_heads * self._d_head``.

    Returns:
      Array of shape ``(n_batch * n_heads, seqlen, d_head)``.
    """
    n_batch, seq_len = x.shape[0], x.shape[1]
    projected = np.dot(x, weights)

    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    heads = np.reshape(projected, (n_batch, seq_len, self._n_heads, self._d_head))
    # -> n_batch, n_heads, seqlen, d_head
    heads = np.transpose(heads, (0, 2, 1, 3))
    # -> n_batch*n_heads, seqlen, d_head
    return np.reshape(heads, (-1, seq_len, self._d_head))
Example #7
    def forward(self, inputs, weights):
        """Runs a single LSTM step.

        Args:
          inputs: A pair ``(x, lstm_state)``, where ``lstm_state`` is ``c`` and
              ``h`` concatenated along the last axis.
          weights: A pair ``(w, b)`` for the dense layer over ``[x, h]``.

        Returns:
          ``(new_h, new_state)``, where ``new_state`` is ``new_c`` and
          ``new_h`` concatenated along the last axis.
        """
        x, state = inputs
        cell, hidden = jnp.split(state, 2, axis=-1)

        kernel, bias = weights
        projected = jnp.dot(jnp.concatenate([x, hidden], axis=-1), kernel) + bias

        # Gate order along the last axis: input gate, new input, forget gate,
        # output gate.
        in_gate, new_input, forget_gate, out_gate = jnp.split(projected, 4, axis=-1)

        new_cell = (cell * math.sigmoid(forget_gate)
                    + math.sigmoid(in_gate) * jnp.tanh(new_input))
        new_hidden = jnp.tanh(new_cell) * math.sigmoid(out_gate)
        return new_hidden, jnp.concatenate([new_cell, new_hidden], axis=-1)
Example #8
 def forward(self, x, weights):
     """Returns the affine map ``np.dot(x, w) + b`` for ``weights == (w, b)``."""
     kernel, bias = weights
     return np.dot(x, kernel) + bias