def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
    """Build the embedding table for positions ``lo`` (inclusive) to ``hi``.

    Each channel is the cosine of a per-position cycle count made of a
    geometric frequency ramp, a quarter-cycle shift on odd channels and a
    noise term obtained from ``self._get_noise``.

    Args:
      lo: where to start sampling
      hi: where to stop sampling
      depth: embedding depth
      rng: rng for the random phase; must be given when ``self._affine`` is set

    Returns:
      embeddings: float[hi - lo, depth]
    """
    # One noise column is shared by each (even, odd) channel pair.
    half_depth = (depth + 1) // 2
    raw_noise = self._get_noise(lo, hi, half_depth)
    # Make the stddev around 1 after 1/drift.
    scaled_noise = raw_noise * self._drift ** 0.5

    positions, channels = onp.mgrid[lo:hi, :depth]
    # Even channels stay cosines; odd channels get a quarter-cycle shift
    # (the sine-phase counterpart of the same frequency).
    pair_idx, is_odd = divmod(channels, 2)

    # Off-by-one correction for odd depth:
    drift = self._drift
    if depth > 2:
        exponent = half_depth / (depth // 2)
        drift = drift ** exponent

    # Spend roughly half the frequencies on noise:
    base_freqs = np.geomspace(0.5, 0.5 * drift ** 2, num=half_depth)
    freq = base_freqs[pair_idx]

    quarter_shift = is_odd / 4
    carrier = freq * positions
    jitter = scaled_noise[:, pair_idx[0, :]] / 4
    cycles = quarter_shift + carrier + jitter
    assert cycles.shape == (hi - lo, depth), cycles.shape

    # Get random phases (one per channel, shared by all positions):
    if self._affine:
        assert rng is not None
        phase = trax.math.random.uniform(rng, (1, depth), minval=0, maxval=1)
        cycles = cycles + phase

    # Convert from cycles to radians:
    two_pi = np.pi * 2
    embeddings = np.cos(two_pi * cycles)

    # Set the last three channels to the time bin features:
    bin_length = self._time_bin_length
    if bin_length is not None:
        last_column = positions[:, -1:]
        inter_bin_idx, intra_bin_idx = divmod(last_column, bin_length)
        bin_parity = inter_bin_idx % 2
        bin_fraction = intra_bin_idx / bin_length
        pieces = [
            embeddings[:, :-3],
            1 / (1 + inter_bin_idx),
            bin_fraction,
            bin_parity.astype(np.float32),
        ]
        embeddings = np.concatenate(pieces, -1)

    assert embeddings.shape == (hi - lo, depth), embeddings.shape
    return embeddings
def learning_rate(step):
    """Map a training step to the learning rate built from the active factors.

    The rate starts at 1.0 and is scaled, in order, by every entry of
    ``factors``. ``factors`` and the schedule constants (``constant``,
    ``warmup_steps``, ``decay_factor``, ``steps_per_decay``,
    ``steps_per_cycle``) are names defined outside this function.

    Args:
      step: current training step.

    Returns:
      A dict with the single key 'learning_rate' holding the float32 rate.

    Raises:
      ValueError: if ``factors`` contains an unrecognised name.
    """
    rate = 1.0
    for factor in factors:
        if factor == 'constant':
            rate *= constant
            continue

        if factor == 'linear_warmup':
            # Ramp from 0 to 1 over the first `warmup_steps` steps, then hold.
            warmup_fraction = step / warmup_steps
            rate *= np.minimum(1.0, warmup_fraction)
            continue

        if factor == 'rsqrt_decay':
            # Flat at 1/sqrt(warmup_steps) until warmup ends, then 1/sqrt(step).
            floored_step = np.maximum(step, warmup_steps)
            rate /= np.sqrt(floored_step)
            continue

        if factor == 'rsqrt_normalized_decay':
            # Same decay, rescaled so the factor equals 1 during warmup.
            rate *= np.sqrt(warmup_steps)
            floored_step = np.maximum(step, warmup_steps)
            rate /= np.sqrt(floored_step)
            continue

        if factor == 'decay_every':
            completed_periods = step // steps_per_decay
            rate *= (decay_factor ** completed_periods)
            continue

        if factor == 'cosine_decay':
            # Position within the current cycle, counted from the end of warmup
            # and clamped at 0 before that.
            cycles_elapsed = (step - warmup_steps) / float(steps_per_cycle)
            progress = np.maximum(0.0, cycles_elapsed)
            angle = np.pi * (progress % 1.0)
            rate *= (0.5 * (1.0 + np.cos(angle)))
            continue

        raise ValueError('Unknown factor %s.' % factor)

    rate = np.asarray(rate, dtype=np.float32)
    return {'learning_rate': rate}