Example 1 (Trax-style fixed-base positional encoding, train/eval only; assumes module-level imports: from trax import fastmath and from trax.fastmath import numpy as jnp)
 def forward(self, x):
   rng = self.rng
   batch_size, length = x.shape[0], x.shape[1]
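   # Positions must fit in n_digits digits even for the smallest base.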
   max_pos = min(self._bases)**self._n_digits
   rng1, rng2, rng3 = fastmath.random.split(rng, 3)
   assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
   positions = jnp.arange(0, length)[None, :]
   if self._mode == 'train':
     # In 1 out of _start_from_zero_one_in training cases, still start from
     # 0 to be exactly as in eval.
     start_from_nonzero = fastmath.random.randint(
         rng1, (batch_size,), 0, self._start_from_zero_one_in)
     start_from_nonzero = jnp.minimum(1, start_from_nonzero)
     random_start = fastmath.random.randint(
         rng2, (batch_size,), 0, max_pos-length)
     random_start *= start_from_nonzero
     positions += random_start[:, None]
   res = []
   for bn, base in enumerate(self._bases):
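     # Decompose each position into n_digits base-`base` digits; each digit
     # value scales that digit's trainable embedding vector.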
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
       cur_indices = jnp.mod(cur_positions, base)
       cur_positions = cur_positions // base
       s = self.weights[bn][i]
       pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
     embeddings = jnp.concatenate(pos_embeddings, axis=-1)
     if self._mode == 'train':
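       # Drop this base's embedding for ~1 in _base_dropout_one_in examples.
       # Note rng3 is reused, so all bases share the same draw per example.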
       base_dropout = fastmath.random.randint(
           rng3, (batch_size,), 0, self._base_dropout_one_in)
       base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
       embeddings *= base_dropout[:, None, None]
     res.append(embeddings)
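   # Sum the per-base embeddings; adding zeros_like(x) broadcasts to x's shape.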
   res = sum(res) + jnp.zeros_like(x)
   return x + res
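The inner loop above is just a mixed-radix digit decomposition. A minimal standalone NumPy sketch (illustrative only, not Trax code; the helper name base_digits is made up) that reproduces it for one base:

import numpy as np

def base_digits(positions, base, n_digits):
  # Digits of each position in the given base, least-significant first;
  # output shape is positions.shape + (n_digits,).
  out = []
  cur = positions
  for _ in range(n_digits):
    out.append(np.mod(cur, base))
    cur = cur // base
  return np.stack(out, axis=-1)

# 11 = 2 + 0*3 + 1*9, so its base-3 digits are [2, 0, 1]:
print(base_digits(np.array([11]), base=3, n_digits=3))  # [[2 0 1]]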
Example 2 (the same layer extended with a trainable start_vec and 'predict'-mode position state; same imports assumed)
 def forward(self, x):
     rng = self.rng
     base_weights, start_vec = self.weights
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = fastmath.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
      # In training, randomize the start position for better generalization.
      # The trainable start_vec compensates by giving the model a way to
      # learn where a sequence actually starts.
     if self._mode == 'train':
          # In 1 out of _start_from_zero_one_in training cases, still start
          # from 0 to be exactly as in eval.
          start_from_nonzero = fastmath.random.randint(
              rng1, (batch_size,), 0, self._start_from_zero_one_in)
          start_from_nonzero = jnp.minimum(1, start_from_nonzero)
          random_start = fastmath.random.randint(
              rng2, (batch_size,), 0, max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     if self._mode == 'predict':
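          # In decoding, continue positions from the number of tokens
          # processed so far (tracked in self.state).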
         positions += self.state
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = base_weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
              base_dropout = fastmath.random.randint(
                  rng3, (batch_size,), 0, self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res)  # Sum embeddings from all bases.
      # Add start_vec to the first position only, marking the start of the
      # sequence.
     res0 = res[:, 0, :][:, None, :]
     start_pos = res0 + start_vec
     if self._mode == 'predict':
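          # Mark the start only on the very first decode step (state == 0).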
         start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
         self.state += length  # Add input length to state.
     res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
     return x + res
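The 'predict' branch maintains a running offset in self.state. A standalone sketch (plain Python/NumPy, not Trax; the chunk lengths are illustrative) of that bookkeeping:

import numpy as np

state = 0  # number of tokens decoded so far, as in self.state
for chunk_len in [1, 1, 3]:
  positions = np.arange(chunk_len) + state
  mark_start = (state == 0)  # start_vec is added only on the first chunk
  print(positions, mark_start)
  state += chunk_len
# Prints [0] True, then [1] False, then [2 3 4] False.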