def forward_with_state(self, x, weights, state, rng):
    """Append fixed-base positional encodings to ``x``.

    Each position is written in every base from ``self._bases`` using
    ``self._n_digits`` digits, least-significant digit first. Every digit
    value is cast to float32 and scaled by the learned weight
    ``weights[bn][i]``. The per-digit vectors are concatenated for each base,
    the per-base results are summed, and the sum is concatenated to ``x``
    along the last axis.

    In ``'train'`` mode two kinds of per-example randomness are applied:
      * Positions are shifted by a random offset in ``[0, max_pos - length)``.
        The exception is roughly ``1 / self._start_from_zero_one_in`` of the
        examples, which start from position 0 exactly as in eval.
      * Each base's contribution is zeroed in roughly
        ``1 / self._base_dropout_one_in`` of the examples, drawn
        independently for every base.

    Args:
      x: input array; axis 0 is read as the batch size and axis 1 as the
        sequence length.
      weights: nested sequence indexed as ``weights[bn][i]`` (base index,
        digit index). NOTE(review): each entry is assumed to broadcast
        against ``(batch, length, 1)``, and the concatenated digits are
        assumed to broadcast with ``x``'s last dim. Confirm against the
        weight initializer.
      state: layer state, returned unchanged.
      rng: JAX PRNG key. It is only consumed in ``'train'`` mode and may be
        ``None`` otherwise.

    Returns:
      ``(output, state)``, where ``output`` is ``x`` concatenated with the
      positional encodings along the last axis.

    Raises:
      AssertionError: if the sequence length is not below
        ``min(self._bases) ** self._n_digits``.
    """
    batch_size, length = x.shape[0], x.shape[1]
    # The smallest base limits how many positions n_digits digits can encode.
    max_pos = min(self._bases) ** self._n_digits
    assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
    train = self._mode == 'train'

    positions = jnp.arange(0, length)[None, :]
    if train:
        # Keys are only needed for train-time randomness. Splitting here lets
        # eval callers pass rng=None.
        rng1, rng2, rng3 = jax.random.split(rng, 3)
        # randint(0, k) is 0 with probability 1/k. After clamping to 1, this
        # flag is 0 in ~1/k of examples, and those start from 0 as in eval.
        start_from_nonzero = jax.random.randint(
            rng1, (batch_size,), 0, self._start_from_zero_one_in)
        start_from_nonzero = jnp.minimum(1, start_from_nonzero)
        # This upper bound keeps every shifted position strictly below max_pos.
        random_start = jax.random.randint(
            rng2, (batch_size,), 0, max_pos - length)
        random_start *= start_from_nonzero
        positions = positions + random_start[:, None]

    res = []
    for bn, base in enumerate(self._bases):
        pos_embeddings = []
        cur_positions = positions
        for i in range(self._n_digits):
            # Peel off the least-significant base-`base` digit.
            cur_indices = jnp.mod(cur_positions, base)
            cur_positions = cur_positions // base
            s = weights[bn][i]
            pos_embeddings.append(
                cur_indices.astype(jnp.float32)[:, :, None] * s)
        embeddings = jnp.concatenate(pos_embeddings, axis=-1)
        if train:
            # Fold in the base index so each base draws its own keep/drop
            # mask. Reusing rng3 unchanged would give every base the same
            # mask, so all bases would always be dropped together.
            base_rng = jax.random.fold_in(rng3, bn)
            base_dropout = jax.random.randint(
                base_rng, (batch_size,), 0, self._base_dropout_one_in)
            base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
            embeddings = embeddings * base_dropout[:, None, None]
        res.append(embeddings)
    # In eval mode `positions` has a leading axis of size 1. Adding
    # zeros_like(x) broadcasts it up to x's batch size before concatenation.
    res = sum(res) + jnp.zeros_like(x)
    return jnp.concatenate([x, res], axis=-1), state
def forward_with_state(self, x, weights=layer_base.EMPTY_WEIGHTS,
                       state=layer_base.EMPTY_STATE, rng=None, **kwargs):
    """Add learned base-``self._base`` digit embeddings to ``x``.

    Every position in ``[0, length)`` is written with ``self._n_digits``
    digits in base ``self._base``, least-significant digit first. Digit ``i``
    selects rows of ``weights[i]`` via ``np.take(..., axis=0)``. The selected
    rows are concatenated along the last axis, given a leading axis of size 1,
    and added to ``x``.

    In ``'train'`` mode a single random offset drawn from
    ``[0, max_pos - length)`` is added to every position. The offset is a
    scalar, so the whole batch shares it.

    Args:
      x: input array; ``np.shape(x)[1]`` is read as the sequence length.
      weights: one lookup table per digit, indexed ``weights[i]``.
        NOTE(review): presumably each table has ``self._base`` rows and the
        concatenated widths match ``x``'s last dim. Confirm at init.
      state: layer state, returned unchanged.
      rng: JAX PRNG key; only used in ``'train'`` mode.
      **kwargs: accepted and ignored.

    Returns:
      ``(x + embeddings, state)``.

    Raises:
      AssertionError: if the sequence length is not below
        ``self._base ** self._n_digits``.
    """
    seq_len = np.shape(x)[1]
    capacity = self._base ** self._n_digits
    assert seq_len < capacity, 'length (%d) >= max_pos (%d)' % (seq_len, capacity)

    offsets = np.arange(0, seq_len)
    if self._mode == 'train':
        offsets += jax.random.randint(rng, (), 0, capacity - seq_len)

    digit_rows = []
    remaining = offsets
    for digit in range(self._n_digits):
        # Look up the current least-significant digit, then shift it away.
        digit_rows.append(
            np.take(weights[digit], np.mod(remaining, self._base), axis=0))
        remaining = remaining // self._base

    table = np.concatenate(digit_rows, axis=-1)
    return x + table[None, :, :], state