def learning_rate(step):
    """Map a training step to a learning rate, clamped from below.

    The rate starts at 1.0 and is scaled once per entry of `factors`, in
    order. `factors`, `constant`, `warmup_steps`, `decay_factor`,
    `steps_per_decay`, `steps_per_cycle` and `minimum` are not parameters;
    they are read from the enclosing scope.

    Recognized factor names:
      constant: multiply by `constant`.
      linear_warmup: multiply by min(1, step / warmup_steps).
      rsqrt_decay: divide by sqrt(max(step, warmup_steps)).
      rsqrt_normalized_decay: like rsqrt_decay, but pre-multiplied by
        sqrt(warmup_steps) so the factor equals 1 until step exceeds
        warmup_steps.
      decay_every: multiply by decay_factor ** (step // steps_per_decay).
      cosine_decay: multiply by 0.5 * (1 + cos(pi * p)) where p is the
        fractional part of max(0, (step - warmup_steps) / steps_per_cycle),
        so the factor restarts every `steps_per_cycle` steps.

    Args:
      step: the current training step.

    Returns:
      The larger of the computed rate and `minimum`, as an array.

    Raises:
      ValueError: if `factors` contains an unrecognized name.
    """
    ret = 1.0
    for name in factors:
        if name == 'constant':
            ret *= constant
        elif name == 'linear_warmup':
            ret *= jnp.minimum(1.0, step / warmup_steps)
        elif name == 'rsqrt_decay':
            ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
        elif name == 'rsqrt_normalized_decay':
            ret *= jnp.sqrt(warmup_steps)
            ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
        elif name == 'decay_every':
            ret *= (decay_factor**(step // steps_per_decay))
        elif name == 'cosine_decay':
            progress = jnp.maximum(
                0.0, (step - warmup_steps) / float(steps_per_cycle))
            ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
        else:
            raise ValueError('Unknown factor %s.' % name)
    # Clamp with the elementwise jnp.maximum. jnp.max(minimum, ret) fails
    # because its second positional argument is `axis`, not a second operand.
    # Unlike a Python-level `if ret <= minimum`, this always returns the same
    # kind of value and does not require `ret` to be a concrete scalar.
    return jnp.maximum(minimum, ret)
def Sinusoidal_Embeddings(positions):
    """Compute sinusoidal embeddings for a 1-D vector of positions.

    The embedding width `d_feature` is not a parameter; it is read from the
    enclosing scope. Frequencies are 1 / 10000**(k / d_feature) for
    k = 0, 2, 4, ... below d_feature. Each row holds the sines for all
    frequencies followed by the cosines.

    Args:
      positions: a one-dimensional array of positions.

    Returns:
      Array of shape (len(positions), d_feature).
    """
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    # arange(0, d_feature, 2) has ceil(d_feature / 2) entries, so an odd
    # d_feature would give d_feature + 1 columns. Trim to d_feature; this is
    # a no-op for even widths.
    return pos_emb[:, :d_feature]
def _sincos(self, start, length, d_feature): """Create the sin-cos tensor of shape [1, length, d_feature].""" position = jnp.arange(0, length)[:, None] + start div_term = jnp.exp( jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature)) sin = jnp.sin(position * div_term) cos = jnp.cos(position * div_term) pe = jnp.concatenate([sin, cos], axis=1) return pe[None, :, :] # [1, length, d_feature]
def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
    """Get embeddings float[length, depth], where length is hi - lo.

    Each channel is the cosine of a phase (measured in cycles) made of a
    per-channel frequency times the time index, a quarter-cycle offset on odd
    channels, a noise term from `self._get_noise`, and, when `self._affine`
    is set, a random per-channel phase. When `self._time_bin_length` is not
    None, the last three channels are replaced by time-bin features.

    Args:
      lo: where to start sampling
      hi: where to stop sampling
      depth: embedding depth
      rng: rng for random phase; must not be None when `self._affine` is set.

    Returns:
      embeddings: float[length, depth]
    """
    # One noise channel per pair of embedding channels.
    # NOTE(review): presumably shaped [hi - lo, (depth + 1) // 2] -- confirm
    # against _get_noise.
    noise = self._get_noise(lo, hi, (depth + 1) // 2)
    # Make the stddev around 1 after 1/drift.
    noise = noise * self._drift**.5

    # t[i, j] is the time index lo + i and c[i, j] is the channel index j.
    t, c = np.mgrid[lo:hi, :depth]
    # Make even channels cos, odd channels sin:
    # c_mod_2 is 1 on odd channels and later adds a quarter cycle there.
    c_div_2, c_mod_2 = divmod(c, 2)

    # Off-by-one correction for odd depth:
    drift = self._drift
    # The guard keeps depth // 2 nonzero in the exponent's denominator.
    if depth > 2:
        drift = drift**(((depth + 1) // 2) / (depth // 2))

    # Spend roughly half the frequencies on noise:
    # Frequencies are geometrically spaced from .5 to .5 * drift**2, and each
    # channel pair (c_div_2) shares one frequency.
    freq = jnp.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2]

    # Phase in cycles: quarter-cycle offset on odd channels, plus
    # frequency * time, plus a quarter of the pair's noise value.
    cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4
    assert cycles.shape == (hi - lo, depth), cycles.shape

    # Get random phases:
    if self._affine:
        assert rng is not None
        # One uniform offset in [0, 1) cycles per channel, shared over time.
        cycles = cycles + trax.fastmath.random.uniform(
            rng, (
                1,
                depth,
            ), minval=0, maxval=1)

    # Convert from cycles to radians:
    embeddings = jnp.cos(jnp.pi * 2 * cycles)

    # Set the last channels to the time bin features:
    if self._time_bin_length is not None:
        # t[:, -1:] is a single column of time indices, split into the bin
        # number and the offset inside the bin.
        inter_bin_idx, intra_bin_idx = divmod(t[:, -1:], self._time_bin_length)
        bin_parity = inter_bin_idx % 2
        bin_fraction = intra_bin_idx / self._time_bin_length
        # Drops the last three channels and appends three feature columns.
        # With depth < 3 the final shape assertion below fails.
        embeddings = jnp.concatenate([
            embeddings[:, :-3],
            1 / (1 + inter_bin_idx),
            bin_fraction,
            bin_parity.astype(jnp.float32),
        ], -1)

    assert embeddings.shape == (hi - lo, depth), embeddings.shape
    return embeddings
def rotate(x):
    """Rotate the feature halves of `x` by position-dependent angles.

    `x` is unpacked as a rank-3 array (batch, length, depth). For position p
    and frequency index k the angle is p * exp(-2k * log(10000) / depth); the
    same angles are used for the first and the second half of the depth axis.
    The result is x * cos(angle) + swap(x) * sin(angle), where swap(x) places
    the negated second half of the depth axis in front of the first half.

    Args:
      x: array of shape (batch, length, depth).

    Returns:
      Array with the same shape as `x`.
    """
    _, seq_len, depth = x.shape

    def _scale_per_position(values, table):
        # Multiplies every batch entry by the (length, depth) table.
        return jnp.einsum('bld,ld->bld', values, table)

    def _swap_halves(values):
        # (first, second) -> (-second, first) along the last axis.
        split = values.shape[-1] // 2
        front = values[..., :split]
        back = values[..., split:]
        return jnp.concatenate((-back, front), axis=front.ndim - 1)

    rates = jnp.exp(jnp.arange(0, depth, 2) * -(jnp.log(10000.0) / depth))
    steps = jnp.arange(seq_len)
    half_angles = jnp.einsum('i,j->ij', steps, rates)
    angles = jnp.concatenate((half_angles, half_angles), axis=-1)
    cos_table = jnp.cos(angles)
    sin_table = jnp.sin(angles)

    kept = _scale_per_position(x, cos_table)
    mixed = _scale_per_position(_swap_halves(x), sin_table)
    return kept + mixed
def Sinusoidal_Embeddings(positions, d_feature):
    """Sinusoidal Embeddings.

    Computes out of 1-D integer absolute position vector the sinusoidal
    embeddings defined like in paper Attention is all you need (2017).
    Embeddings are shaped (positions, d_feature). Each row holds the sines
    for all frequencies followed by the cosines, with frequencies
    1 / 10000**(k / d_feature) for k = 0, 2, 4, ... below d_feature.

    Args:
      positions: a one-dimensional array of positions.
      d_feature: the number of sin-cos features.

    Returns:
      Positional embeddings of shape (len(positions), d_feature).
    """
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    # arange(0, d_feature, 2) has ceil(d_feature / 2) entries, so an odd
    # d_feature would give d_feature + 1 columns. Trim so the documented
    # shape holds; this is a no-op for even widths.
    return pos_emb[:, :d_feature]
def learning_rate(step):
    """Return the learning rate for `step` as a Python float.

    The rate starts at 1.0 and is scaled once per entry of `factors`, in
    order. `factors`, `constant`, `warmup_steps`, `decay_factor`,
    `steps_per_decay` and `steps_per_cycle` are read from the enclosing
    scope.

    Args:
      step: the current training step.

    Returns:
      The product of all requested factors, converted with `float`.

    Raises:
      ValueError: if `factors` contains an unrecognized name.
    """

    def _rsqrt_denominator():
        # sqrt(max(step, warmup_steps)), shared by both rsqrt factors.
        return jnp.sqrt(jnp.maximum(step, warmup_steps))

    rate = 1.0
    for factor in factors:
        if factor == 'cosine_decay':
            # Fraction of the current cycle; held at 0 until warmup is over.
            cycle_pos = jnp.maximum(
                0.0, (step - warmup_steps) / float(steps_per_cycle))
            rate *= (0.5 * (1.0 + jnp.cos(jnp.pi * (cycle_pos % 1.0))))
        elif factor == 'decay_every':
            rate *= (decay_factor**(step // steps_per_decay))
        elif factor == 'rsqrt_normalized_decay':
            rate *= jnp.sqrt(warmup_steps)
            rate /= _rsqrt_denominator()
        elif factor == 'rsqrt_decay':
            rate /= _rsqrt_denominator()
        elif factor == 'linear_warmup':
            rate *= jnp.minimum(1.0, step / warmup_steps)
        elif factor == 'constant':
            rate *= constant
        else:
            raise ValueError('Unknown factor %s.' % factor)
    return float(rate)