コード例 #1
0
ファイル: rel_attention.py プロジェクト: rohan-viz/trax
 def unpad_borders(v):
   padded_total_len = v.shape[2]
   assert padded_total_len == original_l + chunk_len
   pre_padded, mid, post_padded = split_along_l(
       v, chunk_len, padded_total_len - chunk_len, padded_total_len)
   pre = jnp.take(pre_padded, indices=range(chunk_offset), axis=2)
   post = jnp.take(post_padded, indices=range(last_chunk_len), axis=2)
   return jnp.concatenate([pre, mid, post], axis=2)
コード例 #2
0
ファイル: core.py プロジェクト: wangdongya/trax
  def forward(self, x):
    """Returns embedding vectors corresponding to input token id's.

    Args:
      x: Tensor of token id's.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0)
コード例 #3
0
    def forward(self, x):
        """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
        embedded = jnp.take(self.weights, x, axis=0, mode='clip')
        if self._use_bfloat16:  # Return float32 activations w/ bfloat16 weights.
            embedded = embedded.astype(jnp.float32)
        return embedded
コード例 #4
0
ファイル: rse.py プロジェクト: yliu45/trax
def _shuffle_layer(inputs, shuffle_fn):
  """Shuffles the elements according to bitwise left or right rotation.

  Args:
    inputs: Tensor input from previous layer
    shuffle_fn: Shift function rol or ror

  Returns:
    tf.Tensor: Inputs shifted according to shuffle_fn
  """
  seq_length = inputs.shape[1]
  n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1

  indices = np.arange(0, seq_length).astype('int32')
  rev_indices = shuffle_fn(indices, n_bits)
  return jnp.take(inputs, rev_indices, axis=1)
コード例 #5
0
ファイル: rel_attention.py プロジェクト: rohan-viz/trax
 def split_along_l(v, mid_start, mid_end, end):
   pre = jnp.take(v, indices=range(mid_start), axis=2)
   mid = jnp.take(v, indices=range(mid_start, mid_end), axis=2)
   post = jnp.take(v, indices=range(mid_end, end), axis=2)
   return pre, mid, post