Ejemplo n.º 1
0
 def Sinusoidal_Embeddings(positions):
     """Build sine/cosine embeddings for a vector of positions.

     Every position is multiplied by a bank of inverse frequencies
     ``1 / 10000**(k / d_feature)`` for ``k = 0, 2, 4, ...``; the sines of
     those products are followed by their cosines along axis 1.

     NOTE(review): ``d_feature`` is not a parameter; it is read from the
     surrounding scope, which is not visible in this snippet -- confirm
     where it is defined.

     Args:
       positions: array of positions. The einsum spec ``'i,j->ij'`` requires
         it to be one-dimensional.

     Returns:
       A 2-D array with one row per entry of ``positions``: all sine columns
       first, then all cosine columns.
     """
     exponents = jnp.arange(0.0, d_feature, 2.0) / d_feature
     inverse_frequencies = 1 / (10000**exponents)
     # Outer product: one row per position, one column per frequency.
     angles = jnp.einsum('i,j->ij', positions, inverse_frequencies)
     sines = jnp.sin(angles)
     cosines = jnp.cos(angles)
     return jnp.concatenate([sines, cosines], axis=1)
Ejemplo n.º 2
0
 def _sincos(self, start, length, d_feature):
     """Create the sin-cos tensor of shape [1, length, d_feature].

     Positions ``start, start + 1, ..., start + length - 1`` are multiplied
     by the frequencies ``exp(-k * log(10000) / d_feature)`` for
     ``k = 0, 2, 4, ...``. The sines of those products come first along the
     last axis, followed by the cosines.

     Args:
       start: offset added to every position index.
       length: number of consecutive positions to encode.
       d_feature: number of feature columns in the result.

     Returns:
       Array of shape [1, length, d_feature].
     """
     position = jnp.arange(0, length)[:, None] + start
     div_term = jnp.exp(
         jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
     angles = position * div_term
     pe = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
     # An odd d_feature yields ceil(d_feature / 2) frequencies, so the
     # concatenation is one column too wide; trim it to the documented width.
     # For even d_feature the slice is a no-op.
     return pe[None, :, :int(d_feature)]  # [1, length, d_feature]
Ejemplo n.º 3
0
def rotate(x):
    """Rotate the feature halves of ``x`` by position-dependent angles.

    ``x`` must be three-dimensional; its second axis is treated as the
    position index and its last axis as the feature index. For position
    ``p`` and frequency ``exp(-k * log(10000) / d)`` (``k = 0, 2, 4, ...``)
    the angle is their product, and the angle table is repeated so that it
    spans both halves of the feature axis.

    Args:
        x: array of rank 3.

    Returns:
        ``x * cos(angles) + swapped * sin(angles)``, where ``swapped`` is
        ``x`` with its last-axis halves exchanged and the (new) first half
        negated. Same shape as ``x``.
    """
    _, seq_len, depth = x.shape
    log_step = -(jnp.log(10000.0) / depth)
    inverse_frequencies = jnp.exp(jnp.arange(0, depth, 2) * log_step)
    # Outer product: one row per position, one column per frequency.
    angles = jnp.einsum('i,j->ij', jnp.arange(seq_len), inverse_frequencies)
    # Repeat the table so it lines up with both halves of the feature axis.
    tiled_angles = jnp.concatenate((angles, angles), axis=-1)
    cos_table = jnp.cos(tiled_angles)
    sin_table = jnp.sin(tiled_angles)

    def swap_halves(vecs):
        # (front, back) -> (-back, front) along the last axis.
        half = vecs.shape[-1] // 2
        front = vecs[..., :half]
        back = vecs[..., half:]
        return jnp.concatenate((-back, front), axis=front.ndim - 1)

    def scale(vecs, table):
        # Element-wise product, with the table shared across the first axis.
        return jnp.einsum('bld,ld->bld', vecs, table)

    cos_part = scale(x, cos_table)
    sin_part = scale(swap_halves(x), sin_table)
    return cos_part + sin_part
Ejemplo n.º 4
0
def Sinusoidal_Embeddings(positions, d_feature):
  """Sinusoidal Embeddings.

  Computes, from a 1-D vector of absolute positions, the sinusoidal
  embeddings defined in the paper "Attention Is All You Need" (2017).
  Each position is multiplied by the inverse frequencies
  `1 / 10000**(k / d_feature)` for `k = 0, 2, 4, ...`; the sines of those
  products are placed first along axis 1, followed by the cosines.
  Embeddings are shaped (positions, d_feature).

  Args:
    positions: a one-dimensional array of positions.
    d_feature: the number of sin-cos features.

  Returns:
    Positional embeddings of shape (len(positions), d_feature).
  """
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  # An odd d_feature yields ceil(d_feature / 2) frequencies, so the
  # concatenation is one column too wide; trim it to the documented width.
  # For even d_feature the slice is a no-op.
  return pos_emb[:, :int(d_feature)]