Esempio n. 1
0
    def Init(shape, rng):
        """Build an orthogonal random array of the given `shape`, scaled by `stddev`.

        `shape` is left-padded with 1s to rank 2, and all axes but the last are
        collapsed into a row count. A random normal matrix of that 2-D size
        (transposed to be tall when it would be wide) is orthogonalized with a
        QR decomposition, then reshaped back to `shape`. `stddev` is not defined
        here; it is read from the enclosing scope.
        """
        dims = list(shape)
        # Prepend 1s until there are at least two dimensions.
        dims = [1] * (2 - len(dims)) + dims

        # Rows = product of every axis except the last; columns = last axis.
        rows = 1
        for size in dims[:-1]:
            rows *= size
        cols = dims[-1]
        is_wide = rows < cols
        matrix_shape = (cols, rows) if is_wide else (rows, cols)

        # Random normal draw, orthogonalized through QR.
        sample = random.normal(rng, matrix_shape, dtype=np.float32)
        q, r = np.linalg.qr(sample)

        # Flip column signs by sign(diag(R)) so Q is uniformly distributed.
        q = q * np.sign(np.diag(r))

        # A wide request was sampled transposed; undo that, then restore `shape`.
        if is_wide:
            q = np.transpose(q)
        q = np.reshape(q, shape)

        return stddev * q
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
  """Separate per-head contents from per-head positions and merge each kind.

  After the heads are merged, the feature axis is laid out as
  (x0, p0, ..., xH, pH). Each head contributes a content slice followed by a
  position slice of width `POS_VECTOR_SIZE`.

  Args:
    x: input with heads folded into the leading axis; it is reshaped as
      (-1, n_heads, seqlen, d_head), where d_head is the content width plus
      the position width.
    n_heads: number of heads.

  Returns:
    A pair. The first element is the contents of all heads concatenated along
    the last axis. The second is the element-wise mean of the heads' position
    vectors.
  """
  seqlen, d_head = x.shape[1], x.shape[2]
  # (n_batch * n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads * d_head)
  merged = np.reshape(
      np.transpose(np.reshape(x, (-1, n_heads, seqlen, d_head)), (0, 2, 1, 3)),
      (-1, seqlen, n_heads * d_head))
  head_size = int(d_head) - POS_VECTOR_SIZE
  stride = head_size + POS_VECTOR_SIZE  # width of one (x_i, p_i) pair
  offsets = [h * stride for h in range(n_heads)]
  contents = [merged[:, :, o:o + head_size] for o in offsets]
  positions = [merged[:, :, o + head_size:o + stride] for o in offsets]
  combined_position = sum(positions) / float(len(positions))
  return np.concatenate(contents, axis=-1), combined_position
Esempio n. 3
0
    def forward(self, x, weights):
        """Merge heads folded into the batch axis, then apply `weights`.

        `x` is reshaped as (-1, n_heads, seqlen, d_head), brought to
        (n_batch, seqlen, n_heads * d_head), and multiplied by `weights`
        along its last axis with `np.dot`.
        """
        n_heads = self._n_heads
        seqlen, d_head = x.shape[1], x.shape[2]
        # -> n_batch, n_heads, seqlen, d_head
        split = np.reshape(x, (-1, n_heads, seqlen, d_head))
        # -> n_batch, seqlen, n_heads, d_head
        swapped = np.transpose(split, (0, 2, 1, 3))
        # -> n_batch, seqlen, n_heads * d_head
        merged = np.reshape(swapped, (-1, seqlen, n_heads * d_head))
        return np.dot(merged, weights)
Esempio n. 4
0
    def forward(self, x, weights):
        """Project `x` with `weights`, then fold the heads into the batch axis.

        The projection `np.dot(x, weights)` is viewed as
        (n_batch, seqlen, n_heads, d_head) and returned as
        (n_batch * n_heads, seqlen, d_head).
        """
        n_batch, seqlen = x.shape[0], x.shape[1]
        projected = np.dot(x, weights)
        n_heads, d_head = self._n_heads, self._d_head
        # Give heads their own axis and move it next to the batch axis.
        by_head = np.transpose(
            np.reshape(projected, (n_batch, seqlen, n_heads, d_head)),
            (0, 2, 1, 3))
        # Fold heads into the batch axis.
        return np.reshape(by_head, (-1, seqlen, d_head))
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there.

   Args:
     inp: sequence whose first element is the activations `x` and whose
       remaining element(s) are position arrays. `x` is reshaped as
       (n_batch, seqlen, n_heads * d_head). Each position array is indexed as
       rank 3, (n_batch, seqlen, pos_size). When `self._n_pos == 1`, the single
       position array is tiled into every head. Otherwise one position array
       per head is expected.
     weights: unused.

   Returns:
     Array of shape (n_batch * n_heads, seqlen, d_head + pos_size). Each entry
     holds one head's activations with its position vector appended on the
     last axis.
   """
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = tuple(res.shape[:-1]) + (inp[1].shape[-1],)
     # broadcast_to tiles without allocating a zeros array and keeps the
     # positions' dtype. Adding np.zeros(...) promoted it to the default float.
     pos = np.broadcast_to(inp[1][:, None, :, :], pos_shape)
   else:  # As many positions as heads, concatenate them in.
     pos = np.concatenate([p[:, None, :, :] for p in inp[1:]], axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # n_batch, n_heads, seqlen, width -> n_batch*n_heads, seqlen, width.
   # Take the width from the concatenated array instead of assuming
   # d_head + POS_VECTOR_SIZE. A different position width was otherwise
   # silently reinterpreted by the -1 in the reshape, or failed.
   return np.reshape(res, (-1, seqlen, res.shape[-1]))
Esempio n. 6
0
    def _sample_rotation(self, shape, vecs, rng):
        """Samples a rotation matrix, either randomly or based on `vecs`.

        Args:
          shape: requested output shape. When `self._data_rotation` is set it
            must have exactly 3 entries, unpacked as (n_dim, n_hashes, r_div_2).
          vecs: 2-D array indexed by row, (n_vecs, n_dim). Only read when
            `self._data_rotation` is set.
          rng: random key consumed by `jax.random` / `backend.random`.

        Returns:
          If `self._data_rotation` is false, a float32 standard-normal array of
          `shape`. Otherwise an array of shape (n_dim, n_hashes, r_div_2) whose
          entries are differences `vec2 - vec1` of rows taken from `vecs`.
        """

        if not self._data_rotation:
            return jax.random.normal(rng, shape).astype('float32')

        assert len(shape) == 3
        unused_n_dim, n_hashes, r_div_2 = shape

        assert len(vecs.shape) == 2
        n_vecs = vecs.shape[0]

        rng1, rng2 = backend.random.split(rng, num=2)

        # We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
        num_needed = 2 * n_hashes * r_div_2
        if n_vecs < num_needed:
            # Too few rows to sample without replacement: draw each index
            # independently with randint, so repeats are possible.
            # shape = (n_hashes, r_div_2)
            random_idxs_1 = jax.random.randint(rng1, (n_hashes, r_div_2), 0,
                                               n_vecs)
            random_idxs_2 = jax.random.randint(rng2, (n_hashes, r_div_2), 0,
                                               n_vecs)
        else:
            # Sample without replacement.
            # NOTE(review): jax.random.shuffle is deprecated in newer JAX
            # releases in favour of jax.random.permutation. Verify against the
            # pinned JAX version before changing, since the sampled indices
            # must stay reproducible.
            shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
            # First half of the draw -> endpoint 1, second half -> endpoint 2.
            random_idxs = np.reshape(shuffled_indices[:num_needed],
                                     (2, n_hashes, r_div_2))
            random_idxs_1 = random_idxs[0]
            random_idxs_2 = random_idxs[1]

        if self._data_rotation_farthest:
            # For each vec1, choose vec2 among `_data_rotation_farthest_num`
            # random candidates as the one with the smallest |vec1 . vec2|.
            # random_idxs_2 is not used on this path.
            # shape = (n_hashes * r_div_2, )
            random_idxs_1 = np.reshape(random_idxs_1, (-1, ))
            random_vecs_1 = vecs[random_idxs_1]

            # Sample candidates for vec2s.
            # NOTE(review): `rng` was already split into rng1/rng2 above.
            # If backend.random.split behaves like jax.random.split (default
            # num=2), `subrng` here equals `rng2`. Confirm this key reuse is
            # intended.
            rng, subrng = backend.random.split(rng)
            # shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
            candidate_idxs_2 = jax.random.randint(
                subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2),
                0, n_vecs)
            candidate_vecs_2 = vecs[candidate_idxs_2]
            # Negated absolute dot product of each vec1 with each of its
            # candidates, so argmax picks the smallest |dot|.
            # shape = candidate_idxs_2.shape
            distances = -np.abs(
                np.einsum('hd,chd->ch', random_vecs_1, candidate_vecs_2))
            # shape = (n_hashes * r_div_2,)
            farthest_idxs = np.argmax(distances, axis=0)
            # Pick, for each column, the winning candidate row.
            # shape = (n_hashes * r_div_2, n_dim)
            random_vecs_2 = candidate_vecs_2[farthest_idxs,
                                             np.arange(n_hashes * r_div_2)]

            # reshape to (n_hashes, r_div_2, n_dim)
            random_vecs_1 = np.reshape(random_vecs_1, (n_hashes, r_div_2, -1))
            random_vecs_2 = np.reshape(random_vecs_2, (n_hashes, r_div_2, -1))
        else:
            # shape = (n_hashes, r_div_2, n_dim)
            random_vecs_1 = vecs[random_idxs_1]
            random_vecs_2 = vecs[random_idxs_2]

        # shape = (n_dim, n_hashes, r_div_2)
        return np.transpose(random_vecs_2 - random_vecs_1, axes=[2, 0, 1])
Esempio n. 7
0
 def JoinHeads(x):  # pylint: disable=invalid-name
     """Merge the heads axis of `x` back into its feature axis.

     Axes 1 and 2 are swapped, then the result is reshaped to
     (nbatch, -1, n_heads * d_head). `nbatch`, `n_heads` and `d_head` come from
     the enclosing scope.
     """
     heads_last = np.transpose(x, (0, 2, 1, 3))
     return np.reshape(heads_last, (nbatch, -1, n_heads * d_head))
Esempio n. 8
0
 def SplitHeads(x):  # pylint: disable=invalid-name
     """Split the feature axis of `x` into separate heads.

     `x` is reshaped to (nbatch, -1, n_heads, d_head), then axes 1 and 2 are
     swapped, giving (nbatch, n_heads, -1, d_head). `nbatch`, `n_heads` and
     `d_head` come from the enclosing scope.
     """
     per_head = np.reshape(x, (nbatch, -1, n_heads, d_head))
     return np.transpose(per_head, (0, 2, 1, 3))