Example #1
import jax.numpy as jnp
from jax import random


def Orthogonal(stddev=1.0):
    """Wrapper supplying `stddev`, which the original snippet captured from an enclosing scope."""
    def Init(shape, rng):
        """Returns orthogonalized random normal values with the given `shape`."""
        # Have at least 2 elements in shape.
        cur_shape = list(shape)
        while len(cur_shape) < 2:
            cur_shape = [1] + cur_shape

        # Flatten the input shape, keeping the last dimension.
        n_rows = 1
        for dim in cur_shape[:-1]:
            n_rows *= dim
        n_cols = cur_shape[-1]
        flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

        # Generate a random matrix.
        a = random.normal(rng, flat_shape, dtype=jnp.float32)

        # Compute the QR factorization.
        q, r = jnp.linalg.qr(a)

        # Make Q uniformly distributed by fixing the signs from R's diagonal.
        d = jnp.diag(r)
        q *= jnp.sign(d)

        # Transpose and reshape q back if needed.
        if n_rows < n_cols:
            q = jnp.transpose(q)
        q = jnp.reshape(q, shape)

        # Return scaled as requested.
        return stddev * q

    return Init
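
A quick check of Example #1's behavior, assuming the `Orthogonal` wrapper shown above (the JAX calls are standard): for a tall matrix, the columns of the drawn weights should be orthonormal.

    import jax
    import jax.numpy as jnp

    init = Orthogonal(stddev=1.0)
    w = init((256, 128), jax.random.PRNGKey(0))

    # With n_rows >= n_cols, Q has orthonormal columns: w.T @ w ~ I.
    print(jnp.allclose(jnp.dot(w.T, w), jnp.eye(128), atol=1e-4))  # True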
Example #2
import numpy as np  # stand-in for the trax numpy backend used in the original

POS_VECTOR_SIZE = 8  # placeholder value; the original reads this constant from its module


def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
    """Mix x = (x0, p0, ..., xH, pH) into (x0, ..., xH), p_combined.

    The positions are averaged as vectors.

    Args:
      x: input vector, concatenated (x0, p0, ..., xH, pH).
      n_heads: number of heads.

    Returns:
      the vector with combined xs and one with combined positions.
    """
    seqlen = x.shape[1]
    d_head = x.shape[2]
    x = np.reshape(x, (-1, n_heads, seqlen, d_head))
    x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    x = np.reshape(x, (-1, seqlen, n_heads * d_head))
    head_size = int(d_head) - POS_VECTOR_SIZE
    res, positions, idx = [], [], 0
    for _ in range(n_heads):
        res.append(x[:, :, idx:idx + head_size])
        idx += head_size
        positions.append(x[:, :, idx:idx + POS_VECTOR_SIZE])
        idx += POS_VECTOR_SIZE
    combined_position = sum(positions) / float(len(positions))
    return np.concatenate(res, axis=-1), combined_position
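
A minimal shape check for Example #2, reusing the `np` import and the placeholder `POS_VECTOR_SIZE` from above; the batch and head sizes are invented for illustration.

    n_heads, seqlen, head_size = 2, 5, 16
    d_head = head_size + POS_VECTOR_SIZE  # each head carries its position vector
    x = np.random.randn(3 * n_heads, seqlen, d_head).astype(np.float32)

    xs, pos = CombineHeadsPos(x, n_heads=n_heads)
    print(xs.shape)   # (3, 5, 32): n_batch, seqlen, n_heads * head_size
    print(pos.shape)  # (3, 5, 8): position vectors averaged over heads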
Example #3
import jax.numpy as jnp


def compute_attention_heads(x):
    # `n_heads` and `d_head` are captured from the enclosing scope in the original.
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    x = jnp.transpose(x, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    return jnp.reshape(x, (-1, seqlen, d_head))
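
A shape check for Example #3, defining the closure variables at module scope for the demo (they belong to an enclosing function in the original):

    n_heads, d_head = 4, 8
    x = jnp.zeros((2, 10, n_heads * d_head))  # n_batch=2, seqlen=10
    print(compute_attention_heads(x).shape)   # (8, 10, 8): n_batch*n_heads, seqlen, d_head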
Example #4
import jax
import numpy as np  # stand-in for the trax numpy backend


  def create_weights_unbatched(self, input_signature, rng):
    """Creates query/key/value/output projection weights from the input signature."""
    d_model = input_signature[0].shape[-1]
    d_kv_antecedent = input_signature[1].shape[-1]
    rng_q, rng_k, rng_v, rng_o = jax.random.split(rng, 4)
    w_q = self._kernel_initializer((d_model, self.d_qk), rng_q)
    w_k = self._kernel_initializer((d_kv_antecedent, self.d_qk), rng_k)
    w_v = self._kernel_initializer((d_kv_antecedent, self.d_v), rng_v)
    w_o = np.transpose(self._kernel_initializer((d_model, self.d_v), rng_o))
    return (w_q, w_k, w_v, w_o)
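
A hypothetical shape check for Example #4: the `_Stub` class and all sizes below are invented stand-ins for the layer's attributes, and the method is called unbound with the stub as `self`.

    class _Stub:
        d_qk, d_v = 8, 8
        def _kernel_initializer(self, shape, rng):
            return np.ones(shape, dtype=np.float32)  # rng unused in this stub

    sig = (np.zeros((10, 16)), np.zeros((12, 24)))
    w_q, w_k, w_v, w_o = create_weights_unbatched(_Stub(), sig, jax.random.PRNGKey(0))
    print(w_q.shape, w_k.shape, w_v.shape, w_o.shape)  # (16, 8) (24, 8) (24, 8) (8, 16)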
Example #5
  def forward(self, x, weights):
    """Joins the attention heads back together and applies the output projection."""
    seqlen = x.shape[1]
    d_head = x.shape[2]

    # n_batch*n_heads, seqlen, d_head -> n_batch, seqlen, n_heads*d_head
    x = np.reshape(x, (-1, self._n_heads, seqlen, d_head))
    x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

    return np.dot(x, weights)
Example #6
  def forward(self, x, weights):
    """Applies the projection weights, then splits the result into heads."""
    seqlen = x.shape[1]
    res = np.dot(x, weights)

    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    res = np.transpose(res, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    res = np.reshape(res, (-1, seqlen, self._d_head))

    return res
Example #7
import jax
import numpy as np  # stand-in for the trax numpy backend


  def create_weights_unbatched(self, input_signature, rng):
    """Creates projection weights, sharing the query/key projection when `self.share_qk` is set."""
    d_model = input_signature.shape[-1]
    rng_q, rng_k, rng_v, rng_o = jax.random.split(rng, 4)
    w_q = self._kernel_initializer((d_model, self.d_qk), rng_q)
    if not self.share_qk:
      w_k = self._kernel_initializer((d_model, self.d_qk), rng_k)
    w_v = self._kernel_initializer((d_model, self.d_v), rng_v)
    w_o = np.transpose(self._kernel_initializer((d_model, self.d_v), rng_o))
    if self.share_qk:
      return (w_q, w_v, w_o)
    else:
      return (w_q, w_k, w_v, w_o)
Example #8
  def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there."""
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
   res = np.reshape(res, (-1, seqlen, d_head + POS_VECTOR_SIZE))
   return res
Example #9
def JoinHeads(x):  # pylint: disable=invalid-name
    # `nbatch`, `n_heads` and `d_head` come from the enclosing scope in the original.
    return np.reshape(np.transpose(x, (0, 2, 1, 3)),
                      (nbatch, -1, n_heads * d_head))
Example #10
def SplitHeads(x):
    # Splits the last dimension into heads; same closure variables as JoinHeads.
    return np.transpose(np.reshape(x, (nbatch, -1, n_heads, d_head)),
                        (0, 2, 1, 3))
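
The helpers in Examples #9 and #10 are exact inverses. A roundtrip sketch, with the closure variables set at module scope and the sizes invented for illustration:

    import numpy as np  # stand-in for the trax numpy backend

    nbatch, n_heads, d_head = 2, 4, 8
    x = np.random.randn(nbatch, 10, n_heads * d_head).astype(np.float32)

    print(np.allclose(JoinHeads(SplitHeads(x)), x))  # True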
Example #11
import jax.numpy as jnp


def compute_attention_output(x):
    # Undoes the reshaping of `compute_attention_heads`; `n_heads` and `d_head`
    # are captured from the enclosing scope in the original.
    seqlen = x.shape[1]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
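
A roundtrip sketch pairing Example #11 with Example #3, under the same module-scope assumption for `n_heads` and `d_head`:

    n_heads, d_head = 4, 8
    x = jnp.arange(2 * 10 * n_heads * d_head, dtype=jnp.float32).reshape((2, 10, -1))

    out = compute_attention_output(compute_attention_heads(x))
    print(jnp.allclose(out, x))  # True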