Exemplo n.º 1
0
    def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
        """Get embeddings float[length, depth].

    Args:
      lo: where to start sampling
      hi: where to stop sampling
      depth: embedding depth
      rng: rng for random phase

    Returns:
      embeddings: float[length, depth]
    """
        noise = self._get_noise(lo, hi, (depth + 1) // 2)
        # Make the stddev around 1 after 1/drift.
        noise = noise * self._drift**.5

        t, c = onp.mgrid[lo:hi, :depth]
        # Make even channels cos, odd channels sin:
        c_div_2, c_mod_2 = divmod(c, 2)
        # Off-by-one correction for odd depth:
        drift = self._drift
        if depth > 2:
            drift = drift**(((depth + 1) // 2) / (depth // 2))
        # Spend roughly half the frequencies on noise:
        freq = np.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2]
        cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4
        assert cycles.shape == (hi - lo, depth), cycles.shape

        # Get random phases:
        if self._affine:
            assert rng is not None
            cycles = cycles + trax.math.random.uniform(
                rng, (
                    1,
                    depth,
                ), minval=0, maxval=1)

        # Convert from cycles to radians:
        embeddings = np.cos(np.pi * 2 * cycles)

        # Set the last channels to the time bin features:
        if self._time_bin_length is not None:
            inter_bin_idx, intra_bin_idx = divmod(t[:, -1:],
                                                  self._time_bin_length)
            bin_parity = inter_bin_idx % 2
            bin_fraction = intra_bin_idx / self._time_bin_length
            embeddings = np.concatenate([
                embeddings[:, :-3],
                1 / (1 + inter_bin_idx),
                bin_fraction,
                bin_parity.astype(np.float32),
            ], -1)

        assert embeddings.shape == (hi - lo, depth), embeddings.shape
        return embeddings
Exemplo n.º 2
0
 def learning_rate(step):
     """Step to learning rate function.

     Multiplies together one schedule factor per name in ``factors``, in
     that order. ``factors``, ``constant``, ``warmup_steps``,
     ``decay_factor``, ``steps_per_decay`` and ``steps_per_cycle`` are read
     from the enclosing scope (not visible here).

     Supported factor names:
       constant: multiply by ``constant``.
       linear_warmup: multiply by min(1, step / warmup_steps).
       rsqrt_decay: divide by sqrt(max(step, warmup_steps)).
       rsqrt_normalized_decay: like rsqrt_decay but pre-multiplied by
         sqrt(warmup_steps), so this factor is 1 through warmup.
       decay_every: multiply by decay_factor ** (step // steps_per_decay).
       cosine_decay: 1 during warmup, then a half-cosine from 1 toward 0
         that restarts every ``steps_per_cycle`` steps.

     Args:
       step: current step; presumably a number or a numpy/jax array, since
         only numpy-style elementwise ops are applied to it.

     Returns:
       ``{'learning_rate': value}`` where ``value`` is a float32 array.

     Raises:
       ValueError: if ``factors`` contains an unsupported name.
     """
     lr = 1.0
     for factor in factors:
         if factor == 'constant':
             lr *= constant
             continue
         if factor == 'linear_warmup':
             # NOTE(review): divides by warmup_steps; zero is not guarded.
             lr *= np.minimum(1.0, step / warmup_steps)
             continue
         if factor in ('rsqrt_decay', 'rsqrt_normalized_decay'):
             if factor == 'rsqrt_normalized_decay':
                 # Pre-scale so this factor stays at 1 until warmup ends.
                 lr *= np.sqrt(warmup_steps)
             lr /= np.sqrt(np.maximum(step, warmup_steps))
             continue
         if factor == 'decay_every':
             lr *= decay_factor**(step // steps_per_decay)
             continue
         if factor == 'cosine_decay':
             cycle_pos = (step - warmup_steps) / float(steps_per_cycle)
             progress = np.maximum(0.0, cycle_pos)
             # `% 1.0` restarts the cosine at the beginning of every cycle.
             lr *= 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0)))
             continue
         raise ValueError('Unknown factor %s.' % factor)
     return {'learning_rate': np.asarray(lr, dtype=np.float32)}