예제 #1
0
 def forward_with_state(self, x, weights, state, rng):
     """Concatenates multi-base digit positional encodings onto ``x``.

     Every position is written in each base of ``self._bases`` using
     ``self._n_digits`` digits, least significant first. For base number
     ``bn``, digit ``i`` is turned into a vector by scaling ``weights[bn][i]``
     by the digit value. The per-digit vectors are concatenated along the
     last axis, the per-base results are summed, and that sum is appended to
     ``x`` along the last axis.

     In ``'train'`` mode:
       * each example's positions are shifted by a random offset drawn from
         ``[0, max_pos - length)``, except that with probability
         ``1 / self._start_from_zero_one_in`` the offset is kept at 0;
       * each base's contribution is zeroed for an example with probability
         ``1 / self._base_dropout_one_in``, independently per base.

     Args:
       x: Input activations; dims 0 and 1 are read as batch size and length.
         Assumed to be (batch, length, d) with the summed encodings
         broadcast-compatible with it — TODO confirm against weight init.
       weights: Nested sequence indexed as ``weights[base_index][digit]``.
       state: Layer state, returned unchanged.
       rng: PRNG key. It is always split, but the resulting keys are only
         consumed in ``'train'`` mode.

     Returns:
       Tuple ``(concatenate([x, encodings], axis=-1), state)``.

     Raises:
       AssertionError: If ``length >= min(self._bases) ** self._n_digits``.
     """
     batch_size, length = x.shape[0], x.shape[1]
     # The smallest base bounds how many distinct positions are representable.
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = math.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     is_training = self._mode == 'train'
     if is_training:
         # With probability 1/start_from_zero_one_in keep the offset at 0 so
         # some training examples see exactly the eval-time positions.
         start_from_nonzero = jax.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = jax.random.randint(rng2, (batch_size, ), 0,
                                           max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             # Digit i (least significant first) of each position in `base`.
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if is_training:
             # Derive a distinct key per base. Reusing rng3 directly would
             # draw the identical mask for every base, dropping all bases
             # together instead of each one independently.
             base_rng = jax.random.fold_in(rng3, bn)
             base_dropout = jax.random.randint(base_rng, (batch_size, ), 0,
                                               self._base_dropout_one_in)
             # A draw of 0 (probability 1/base_dropout_one_in) zeroes this
             # base's contribution for that example.
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     # Outside training, `positions` has a leading dim of 1. Adding
     # zeros_like(x) broadcasts the encodings to x's batch size, so the
     # concatenation below lines up.
     res = sum(res) + jnp.zeros_like(x)
     return jnp.concatenate([x, res], axis=-1), state
예제 #2
0
 def forward_with_state(self,
                        x,
                        weights=layer_base.EMPTY_WEIGHTS,
                        state=layer_base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     """Adds fixed-base digit positional embeddings to ``x``.

     Each position is written in base ``self._base`` with ``self._n_digits``
     digits, least significant first. Digit ``i`` selects a row of
     ``weights[i]`` (``np.take`` along axis 0). The selected rows are
     concatenated along the last axis and added to ``x``, broadcast over
     the leading (batch) dimension.

     In ``'train'`` mode, all positions are shifted by a single random
     offset drawn from ``[0, max_pos - length)`` and shared by the whole
     batch.

     Args:
       x: Input activations; dim 1 is read as the sequence length. Assumed
         to be (batch, length, d), with d equal to the concatenated
         embedding width — TODO confirm.
       weights: Sequence of per-digit lookup tables, indexed by digit value.
       state: Layer state, returned unchanged.
       rng: PRNG key; required in ``'train'`` mode.
       **kwargs: Ignored.

     Returns:
       Tuple ``(x + embeddings, state)``.

     Raises:
       AssertionError: If ``length >= self._base ** self._n_digits``.
       ValueError: If ``self._mode == 'train'`` and ``rng`` is None.
     """
     length = np.shape(x)[1]
     max_pos = self._base**self._n_digits
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = np.arange(0, length)
     if self._mode == 'train':
         if rng is None:
             raise ValueError('rng is required in train mode.')
         positions += jax.random.randint(rng, (), 0, max_pos - length)
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
         cur_indices = np.mod(cur_positions, self._base)
         # Out-of-place division: `//=` would write through this alias into
         # `positions` when `np` is a mutable-array backend such as NumPy.
         cur_positions = cur_positions // self._base
         pos_embeddings.append(np.take(weights[i], cur_indices, axis=0))
     embeddings = np.concatenate(pos_embeddings, axis=-1)
     return (x + embeddings[None, :, :], state)