def __call__(self,
             input_ids,
             input_mask,
             type_ids,
             masked_lm_positions,
             masked_lm_labels,
             masked_lm_weights,
             next_sentence_labels,
             deterministic=False):
    """Runs the pre-training model and returns its loss and metrics.

    The inputs are encoded, then two heads are applied: a masked language
    model head over the gathered masked positions and a two-way next
    sentence classification head over the pooled output.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] ids partitioning the input
        into different types.
      masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] indices
        indicating which inputs are masked.
      masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true labels
        for the masked inputs.
      masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] relative
        weighting for the masked inputs.
      next_sentence_labels: <int>[BATCH_SIZE, 1] labels for the next sentence
        prediction task.
      deterministic: Flag forwarded to the encoder that controls whether
        dropout is applied to the input.

    Returns:
      Loss and metrics for the given inputs, as produced by
      `_compute_pretraining_metrics`.
    """
    encoder = EncoderModel(
        self.config, random_seed=self.random_seed, name="encoder")
    sequence_output, pooled_output = encoder(
        input_ids, input_mask, type_ids, deterministic=deterministic)

    # Masked LM head: keep only the encodings at the masked positions, then
    # apply dense -> GELU -> layer norm before projecting to logits.
    masked_encodings = layers.gather(sequence_output, masked_lm_positions)
    predictions_dense = nn.Dense(
        self.config.d_emb,
        kernel_init=default_kernel_init,
        name="predictions_dense")
    transformed = nn.gelu(predictions_dense(masked_encodings))
    predictions_norm = nn.LayerNorm(
        epsilon=LAYER_NORM_EPSILON, name="predictions_layer_norm")
    normalized = predictions_norm(transformed)

    # The output projection uses the kernel returned by
    # `_get_embedding_table`, so it is built only after the encoder has run.
    lm_projection = layers.OutputProjection(
        kernel=self._get_embedding_table(), name="predictions_output")
    masked_lm_logits = lm_projection(normalized)

    # Next sentence prediction head: two-way classifier on the pooled output.
    nsp_classifier = layers.OutputProjection(
        n_out=2, kernel_init=default_kernel_init, name="classification")
    next_sentence_logits = nsp_classifier(pooled_output)

    return _compute_pretraining_metrics(
        masked_lm_logits,
        next_sentence_logits,
        masked_lm_labels,
        masked_lm_weights,
        next_sentence_labels)
def test_gather_incorrect_batch_sizes(self):
    """Checks that gather rejects indices whose batch size differs."""
    single_sequence = jnp.arange(12.).reshape(4, 3)
    # sequences: [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
    sequences = jnp.array([single_sequence, -single_sequence])

    expected_message = (
        "Input sequence and indices must have the same batch size")
    with self.assertRaisesRegex(ValueError, expected_message):
        # positions: [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [1,2]; the batch
        # size of 1 deliberately disagrees with the 2 sequences above.
        positions = jnp.array([[1, 2]])
        _ = layers.gather(sequence=sequences, indices=positions)
def test_gather_bad_indices(self):
    """Checks that gather rejects more predictions than sequence positions."""
    single_sequence = jnp.arange(12.).reshape(4, 3)
    # sequences: [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
    sequences = jnp.array([single_sequence, -single_sequence])

    expected_message = (
        "predictions per sequence cannot be greater than the maximum sequence")
    with self.assertRaisesRegex(ValueError, expected_message):
        # positions: [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [2,5]; five
        # predictions per sequence exceed the sequence length of 4.
        per_sequence = jnp.arange(5)
        positions = jnp.array([per_sequence, jnp.arange(5)])
        _ = layers.gather(sequence=sequences, indices=positions)
def test_gather(self):
    """Checks the shape and values returned by gather for valid inputs."""
    single_sequence = jnp.arange(12.).reshape(4, 3)
    # sequences: [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
    sequences = jnp.array([single_sequence, -single_sequence])
    # positions: [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [2,2].
    positions = jnp.array([[0, 3], [1, 2]])

    outputs = layers.gather(sequence=sequences, indices=positions)

    # Result: [BATCH_SIZE * MAX_PREDICTIONS_PER_SEQ, HIDDEN_DIM] = [4,3].
    self.assertEqual(outputs.shape, (4, 3))

    # Rows 0 and 3 of the first sequence, then rows 1 and 2 of the negated
    # second sequence.
    expected_rows = [
        [0, 1, 2],
        [9, 10, 11],
        [-3, -4, -5],
        [-6, -7, -8],
    ]
    expected = jnp.array(expected_rows, dtype=jnp.float32)
    np.testing.assert_allclose(outputs, expected, atol=1e-12)