def Sinusoidal_Embeddings(positions):
    """Build sine/cosine embeddings for a 1-D vector of positions.

    The feature width ``d_feature`` is not a parameter of this function; it
    is looked up from the surrounding scope.

    Args:
        positions: a one-dimensional array of positions.

    Returns:
        A 2-D array with one row per position. The first half of the columns
        holds the sines and the second half the cosines of
        ``position * inv_freq``.
    """
    # One inverse frequency per even feature index: 10000^(-i / d_feature).
    exponents = jnp.arange(0.0, d_feature, 2.0) / d_feature
    inv_freq = 1 / (10000**exponents)
    # Outer product: angles[i, j] = positions[i] * inv_freq[j].
    angles = jnp.einsum('i,j->ij', positions, inv_freq)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)
def _sincos(self, start, length, d_feature): """Create the sin-cos tensor of shape [1, length, d_feature].""" position = jnp.arange(0, length)[:, None] + start div_term = jnp.exp( jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature)) sin = jnp.sin(position * div_term) cos = jnp.cos(position * div_term) pe = jnp.concatenate([sin, cos], axis=1) return pe[None, :, :] # [1, length, d_feature]
def rotate(x):
    """Apply a position-dependent rotation to the last axis of ``x``.

    Args:
        x: rank-3 array; the shape is unpacked as (batch, length, depth).

    Returns:
        Array of the same shape as ``x``, computed as
        ``x * cos(angles) + rotate_half(x) * sin(angles)``, where
        ``angles[l, :]`` is position ``l`` times the inverse frequencies,
        repeated twice along the last axis.
    """

    def _rotate_half(vecs):
        # Swap the two halves of the last axis, negating the upper half.
        half = vecs.shape[-1] // 2
        front, back = vecs[..., :half], vecs[..., half:]
        return jnp.concatenate((-back, front), axis=-1)

    def _scale(vecs, table):
        # Elementwise product, broadcasting the table over the batch axis.
        return jnp.einsum('bld,ld->bld', vecs, table)

    _, seq_len, depth = x.shape
    log_scale = -(jnp.log(10000.0) / depth)
    inv_freq = jnp.exp(jnp.arange(0, depth, 2) * log_scale)
    angles = jnp.einsum('i,j->ij', jnp.arange(seq_len), inv_freq)
    # Duplicate so that both halves of the feature axis share the same angles.
    angles = jnp.concatenate((angles, angles), axis=-1)
    return _scale(x, jnp.cos(angles)) + _scale(_rotate_half(x), jnp.sin(angles))
def Sinusoidal_Embeddings(positions, d_feature):
    """Sinusoidal position embeddings.

    Maps a 1-D vector of absolute positions to the sinusoidal embeddings
    defined in the paper "Attention is all you need" (2017). The result is
    shaped (positions, d_feature).

    Args:
        positions: a one-dimensional array of positions.
        d_feature: the number of sin-cos features.

    Returns:
        Positional embeddings: sines in the first half of the feature axis,
        cosines in the second half.
    """
    # One inverse frequency per even feature index: 10000^(-i / d_feature).
    exponents = jnp.arange(0.0, d_feature, 2.0) / d_feature
    inv_freq = 1 / (10000**exponents)
    # Outer product: angles[i, j] = positions[i] * inv_freq[j].
    angles = jnp.einsum('i,j->ij', positions, inv_freq)
    sines = jnp.sin(angles)
    cosines = jnp.cos(angles)
    return jnp.concatenate([sines, cosines], axis=1)