Пример #1
0
 def learning_rate(step):
     """Compute the learning rate for a single training step.

     The rate starts at 1.0, every entry of ``factors`` is applied to it in
     order, and the result is clamped from below by ``minimum``. ``factors``,
     ``constant``, ``warmup_steps``, ``decay_factor``, ``steps_per_decay``,
     ``steps_per_cycle`` and ``minimum`` are read from the enclosing scope.

     Args:
       step: training step at which to evaluate the schedule (a scalar).

     Returns:
       The learning rate as a Python float, never smaller than ``minimum``.

     Raises:
       ValueError: if ``factors`` contains an unknown factor name.
     """
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             # Ramp linearly from 0 to 1 over the first warmup_steps steps.
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             # Flat at 1/sqrt(warmup_steps) during warmup, 1/sqrt(step) after.
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             # Same decay, rescaled so the factor equals 1.0 during warmup.
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             # Multiply by decay_factor once per completed steps_per_decay.
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             # Progress is measured in cycles after warmup; the fractional
             # part restarts the half-cosine at the start of every cycle.
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     # Clamp with the element-wise jnp.maximum. The reduction jnp.max takes an
     # axis as its second positional argument, which is what produced
     # "TypeError: 'numpy.float64' object cannot be interpreted as an integer".
     return float(jnp.maximum(minimum, ret))
Пример #2
0
 def Sinusoidal_Embeddings(positions):
     """Build sinusoidal position embeddings of width ``d_feature``.

     ``d_feature`` is read from the enclosing scope. The sine features come
     first, followed by the cosine features.

     Args:
       positions: a one-dimensional array of positions.

     Returns:
       An array of shape (len(positions), d_feature).
     """
     inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
     sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
     # arange yields ceil(d_feature / 2) frequencies. Keeping only
     # floor(d_feature / 2) cosine columns makes the width exactly d_feature
     # for odd sizes too; for even sizes the slice keeps every column.
     num_cos = int(d_feature) // 2
     pos_emb = jnp.concatenate(
         [jnp.sin(sinusoid_freq),
          jnp.cos(sinusoid_freq[:, :num_cos])], axis=1)
     return pos_emb
Пример #3
0
 def _sincos(self, start, length, d_feature):
     """Create the sin-cos tensor of shape [1, length, d_feature]."""
     position = jnp.arange(0, length)[:, None] + start
     div_term = jnp.exp(
         jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
     sin = jnp.sin(position * div_term)
     cos = jnp.cos(position * div_term)
     pe = jnp.concatenate([sin, cos], axis=1)
     return pe[None, :, :]  # [1, length, d_feature]
Пример #4
0
    def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
        """Get embeddings float[length, depth], where length is hi - lo.

        Each channel is a cosine of a per-channel frequency times the time
        index, perturbed by noise obtained from self._get_noise. When
        self._time_bin_length is not None, the last three channels are
        replaced by time-bin features.

        Args:
          lo: where to start sampling
          hi: where to stop sampling
          depth: embedding depth
          rng: rng for random phase; must not be None when self._affine is
            truthy (checked with an assert).

        Returns:
          embeddings: float[length, depth]
        """
        # One noise channel per pair of embedding channels.
        # NOTE(review): presumably shaped (hi - lo, (depth + 1) // 2) --
        # confirm against _get_noise.
        noise = self._get_noise(lo, hi, (depth + 1) // 2)
        # Make the stddev around 1 after 1/drift.
        noise = noise * self._drift**.5

        # t[i, j] = lo + i (time index), c[i, j] = j (channel index); both
        # grids have shape (hi - lo, depth).
        t, c = np.mgrid[lo:hi, :depth]
        # Make even channels cos, odd channels sin:
        c_div_2, c_mod_2 = divmod(c, 2)
        # Off-by-one correction for odd depth:
        drift = self._drift
        # The guard also avoids dividing by depth // 2 == 0 when depth is 1.
        if depth > 2:
            drift = drift**(((depth + 1) // 2) / (depth // 2))
        # Spend roughly half the frequencies on noise:
        # One geometrically spaced frequency per channel pair, expanded to a
        # per-channel grid by indexing with the pair index.
        freq = jnp.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2]
        # Phase in cycles: a quarter-cycle offset on odd channels, plus
        # frequency times time, plus the pair's noise scaled to a quarter cycle.
        cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4
        assert cycles.shape == (hi - lo, depth), cycles.shape

        # Get random phases:
        if self._affine:
            assert rng is not None
            # One uniform phase offset per channel, shared across all time
            # steps through the leading dimension of size 1.
            cycles = cycles + trax.fastmath.random.uniform(
                rng, (
                    1,
                    depth,
                ), minval=0, maxval=1)

        # Convert from cycles to radians:
        embeddings = jnp.cos(jnp.pi * 2 * cycles)

        # Set the last channels to the time bin features:
        if self._time_bin_length is not None:
            # t[:, -1:] is a single column holding the time index of each row.
            inter_bin_idx, intra_bin_idx = divmod(t[:, -1:],
                                                  self._time_bin_length)
            bin_parity = inter_bin_idx % 2
            bin_fraction = intra_bin_idx / self._time_bin_length
            # Drop the last three channels and append three one-column
            # features so the total width stays equal to depth.
            embeddings = jnp.concatenate([
                embeddings[:, :-3],
                1 / (1 + inter_bin_idx),
                bin_fraction,
                bin_parity.astype(jnp.float32),
            ], -1)

        assert embeddings.shape == (hi - lo, depth), embeddings.shape
        return embeddings
Пример #5
0
def rotate(x):
    """Apply a position-dependent rotation to the feature halves of ``x``.

    Args:
      x: array with three dimensions, indexed as (b, l, d) by the einsum
        calls below.

    Returns:
      An array with the same shape as ``x``: ``x * cos(angles)`` plus the
      half-swapped ``x`` times ``sin(angles)``.
    """
    _, seq_len, depth = x.shape

    # One angle per (position, frequency) pair, duplicated so that both
    # halves of the feature axis share the same angles.
    inv_freq = jnp.exp(jnp.arange(0, depth, 2) * -(jnp.log(10000.0) / depth))
    half_angles = jnp.einsum('i,j->ij', jnp.arange(seq_len), inv_freq)
    angles = jnp.concatenate((half_angles, half_angles), axis=-1)

    # Swap the two halves of the last axis, negating the second half.
    split = x.shape[-1] // 2
    swapped = jnp.concatenate((-x[..., split:], x[..., :split]), axis=-1)

    cos_part = jnp.einsum('bld,ld->bld', x, jnp.cos(angles))
    sin_part = jnp.einsum('bld,ld->bld', swapped, jnp.sin(angles))
    return cos_part + sin_part
Пример #6
0
def Sinusoidal_Embeddings(positions, d_feature):
  """Sinusoidal Embeddings.

  Computes out of 1-D absolute position vector the sinusoidal embeddings
  defined like in paper Attention is all you need (2017). Embeddings are
  shaped (positions, d_feature), with the sine features first and the cosine
  features after them.

  Args:
    positions: a one-dimensional array of positions.
    d_feature: the number of sin-cos features.

  Returns:
    Positional embeddings of shape (len(positions), d_feature).
  """
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  # arange yields ceil(d_feature / 2) frequencies. Keeping only
  # floor(d_feature / 2) cosine columns makes the width exactly d_feature for
  # odd sizes too; for even sizes the slice keeps every column.
  num_cos = int(d_feature) // 2
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq[:, :num_cos])], axis=1)
  return pos_emb
Пример #7
0
 def learning_rate(step):
     """Evaluate the composed schedule at one training step.

     The rate starts at 1.0 and every entry of ``factors`` (read from the
     enclosing scope, like the other schedule settings used below) is
     applied to it in order.

     Args:
       step: training step at which to evaluate the schedule.

     Returns:
       The resulting learning rate converted to a Python float.

     Raises:
       ValueError: if an entry of ``factors`` is not a known factor name.
     """
     rate = 1.0
     for factor in factors:
         if factor == 'constant':
             rate *= constant
         elif factor == 'linear_warmup':
             # Ramp linearly from 0 to 1 over the first warmup_steps steps.
             warmup_fraction = step / warmup_steps
             rate *= jnp.minimum(1.0, warmup_fraction)
         elif factor == 'rsqrt_decay':
             # The step is held at warmup_steps until warmup is over.
             clipped_step = jnp.maximum(step, warmup_steps)
             rate /= jnp.sqrt(clipped_step)
         elif factor == 'rsqrt_normalized_decay':
             # Same decay, rescaled so the factor equals 1.0 during warmup.
             rate *= jnp.sqrt(warmup_steps)
             clipped_step = jnp.maximum(step, warmup_steps)
             rate /= jnp.sqrt(clipped_step)
         elif factor == 'decay_every':
             # One multiplication by decay_factor per completed period.
             completed_periods = step // steps_per_decay
             rate *= (decay_factor**completed_periods)
         elif factor == 'cosine_decay':
             # Cycles elapsed since warmup, floored at zero; the fractional
             # part restarts the half-cosine at the start of every cycle.
             elapsed_cycles = (step - warmup_steps) / float(steps_per_cycle)
             progress = jnp.maximum(0.0, elapsed_cycles)
             cycle_position = progress % 1.0
             rate *= (0.5 * (1.0 + jnp.cos(jnp.pi * cycle_position)))
         else:
             raise ValueError('Unknown factor %s.' % factor)
     return float(rate)