def forward(self, x):
    """Adds mixed-base positional embeddings to the activations ``x``.

    Each position is decomposed into ``self._n_digits`` digits in every base
    from ``self._bases``; the per-digit embeddings are concatenated within a
    base and the per-base results are summed. In ``'train'`` mode the
    sequence start is shifted to a random position (kept at 0 roughly once
    in ``self._start_from_zero_one_in`` cases, to match eval) and each
    example's base embeddings may be zeroed out ("base dropout").
    """
    rng1, rng2, rng3 = fastmath.random.split(self.rng, 3)
    n_batch, seq_len = x.shape[0], x.shape[1]
    # Largest position representable with _n_digits digits in the smallest base.
    max_pos = min(self._bases) ** self._n_digits
    assert seq_len < max_pos, 'length (%d) >= max_pos (%d)' % (seq_len, max_pos)

    positions = jnp.arange(0, seq_len)[None, :]
    if self._mode == 'train':
        # In 1% of training cases still start from 0 to be exactly as in eval.
        nonzero_start = jnp.minimum(
            1, fastmath.random.randint(
                rng1, (n_batch,), 0, self._start_from_zero_one_in))
        offsets = fastmath.random.randint(
            rng2, (n_batch,), 0, max_pos - seq_len) * nonzero_start
        positions = positions + offsets[:, None]

    per_base = []
    for base_idx, base in enumerate(self._bases):
        digit_vectors = []
        remaining = positions
        # Peel off digits least-significant first; each digit value scales
        # its own embedding slice.
        for digit in range(self._n_digits):
            digit_value = jnp.mod(remaining, base)
            remaining = remaining // base
            weight = self.weights[base_idx][digit]
            digit_vectors.append(
                digit_value.astype(jnp.float32)[:, :, None] * weight)
        base_embedding = jnp.concatenate(digit_vectors, axis=-1)
        if self._mode == 'train':
            # NOTE(review): rng3 is reused for every base, so the dropout
            # decision is shared across all bases — confirm this is intended.
            keep = jnp.minimum(
                1, fastmath.random.randint(
                    rng3, (n_batch,), 0, self._base_dropout_one_in)
            ).astype(jnp.float32)
            base_embedding = base_embedding * keep[:, None, None]
        per_base.append(base_embedding)

    # Adding zeros_like(x) broadcasts the summed embeddings against x,
    # exactly as in the original formulation.
    total = sum(per_base) + jnp.zeros_like(x)
    return x + total
def forward(self, x):
    """Returns ``x`` plus mixed-base positional embeddings.

    Like the plain multi-base encoding, but with two additions: a trainable
    ``start_vec`` added only at the first sequence position (so the model can
    tell where a sequence begins even under randomized training starts), and
    ``'predict'``-mode support that offsets positions by ``self.state`` — the
    number of tokens already decoded.

    # assumes x is (batch, length, d_feature) — TODO confirm against callers
    """
    rng = self.rng
    # Weights come as a pair: per-base digit embedding tables and the
    # trainable vector added at position 0.
    base_weights, start_vec = self.weights
    batch_size, length = x.shape[0], x.shape[1]
    # Largest position representable with _n_digits digits in the smallest base.
    max_pos = min(self._bases)**self._n_digits
    rng1, rng2, rng3 = fastmath.random.split(rng, 3)
    assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
    positions = jnp.arange(0, length)[None, :]
    # In training we'll randomize starts for better generalization.
    # We use the trainable start_vec to compensate and give model a way
    # to learn what is the starting position in a sequence.
    if self._mode == 'train':
        # In 1% of training cases still start from 0 to be exactly as in eval.
        start_from_nonzero = fastmath.random.randint(
            rng1, (batch_size,), 0, self._start_from_zero_one_in)
        start_from_nonzero = jnp.minimum(1, start_from_nonzero)
        random_start = fastmath.random.randint(rng2, (batch_size,), 0,
                                               max_pos - length)
        # Zero out the random shift for the "start from zero" examples.
        random_start *= start_from_nonzero
        positions += random_start[:, None]
    if self._mode == 'predict':
        # Autoregressive decoding: shift by the number of tokens already
        # processed, tracked in self.state.
        positions += self.state
    res = []
    for bn, base in enumerate(self._bases):
        pos_embeddings = []
        cur_positions = positions
        # Peel off digits least-significant first; each digit value scales
        # its own embedding slice.
        for i in range(self._n_digits):
            cur_indices = jnp.mod(cur_positions, base)
            cur_positions = cur_positions // base
            s = base_weights[bn][i]
            pos_embeddings.append(
                cur_indices.astype(jnp.float32)[:, :, None] * s)
        embeddings = jnp.concatenate(pos_embeddings, axis=-1)
        if self._mode == 'train':
            # NOTE(review): rng3 is reused for every base, so the dropout
            # decision is shared across all bases — confirm this is intended.
            base_dropout = fastmath.random.randint(
                rng3, (batch_size,), 0, self._base_dropout_one_in)
            base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
            embeddings *= base_dropout[:, None, None]
        res.append(embeddings)
    res = sum(res)  # Sum embeddings from all bases.
    # Add start_vec to the first position only to mark it as starting.
    res0 = res[:, 0, :][:, None, :]
    start_pos = res0 + start_vec
    if self._mode == 'predict':
        # Only the very first decode step (state == 0) gets start_vec;
        # later chunks keep the plain embedding at their first position.
        start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
    self.state += length  # Add input length to state.
    res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
    return x + res