Example #1
    def __call__(self,
                 input_ids,
                 input_mask,
                 type_ids,
                 masked_lm_positions,
                 masked_lm_labels,
                 masked_lm_weights,
                 next_sentence_labels,
                 deterministic=False):
        """Applies pre-training model on inputs.

        Args:
          input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
          input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
            inputs from padding. Only used by BERT.
          type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] ids partitioning the input
            into different types.
          masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] indices
            indicating which inputs are masked.
          masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true labels
            for the masked inputs.
          masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] relative
            weighting for the masked inputs.
          next_sentence_labels: <int>[BATCH_SIZE, 1] labels for the next sentence
            prediction task.
          deterministic: if True, dropout is not applied to the inputs.

        Returns:
          Loss and metrics for the given inputs.
        """
        sequence_output, pooled_output = EncoderModel(
            self.config, random_seed=self.random_seed,
            name="encoder")(input_ids,
                            input_mask,
                            type_ids,
                            deterministic=deterministic)

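        # Masked language modeling head: gather the hidden vectors at the masked
        # positions, transform them, and project back onto the vocabulary with
        # the (tied) input embedding table as the output kernel.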
        masked_lm_output = layers.gather(sequence_output, masked_lm_positions)
        masked_lm_output = nn.Dense(self.config.d_emb,
                                    kernel_init=default_kernel_init,
                                    name="predictions_dense")(masked_lm_output)
        masked_lm_output = nn.gelu(masked_lm_output)
        masked_lm_output = nn.LayerNorm(
            epsilon=LAYER_NORM_EPSILON,
            name="predictions_layer_norm")(masked_lm_output)
        masked_lm_logits = layers.OutputProjection(
            kernel=self._get_embedding_table(),
            name="predictions_output")(masked_lm_output)

        next_sentence_logits = layers.OutputProjection(
            n_out=2, kernel_init=default_kernel_init,
            name="classification")(pooled_output)

        return _compute_pretraining_metrics(masked_lm_logits,
                                            next_sentence_logits,
                                            masked_lm_labels,
                                            masked_lm_weights,
                                            next_sentence_labels)
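_compute_pretraining_metrics is not defined in this example. As a rough illustration only, the sketch below shows what its masked-LM term typically looks like in BERT-style pretraining; the function name, signature, and exact normalization are assumptions, not the library's actual code.

import jax
import jax.numpy as jnp


def masked_lm_loss(masked_lm_logits, masked_lm_labels, masked_lm_weights):
    """Weighted cross-entropy over masked positions (hypothetical sketch).

    Args:
      masked_lm_logits: <float>[BATCH_SIZE * MAX_PREDICTIONS_PER_SEQ, VOCAB_SIZE]
        logits from the "predictions_output" projection.
      masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true token ids.
      masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ]; zeros mark
        padded prediction slots.

    Returns:
      Scalar loss averaged over the non-padded masked positions.
    """
    log_probs = jax.nn.log_softmax(masked_lm_logits, axis=-1)
    labels = masked_lm_labels.reshape(-1)
    weights = masked_lm_weights.reshape(-1)
    # Negative log-likelihood of the true token at each masked position.
    per_position_nll = -jnp.take_along_axis(
        log_probs, labels[:, None], axis=-1)[:, 0]
    return jnp.sum(per_position_nll * weights) / (jnp.sum(weights) + 1e-5)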
Example #2
    def test_gather_incorrect_batch_sizes(self):
        example = jnp.arange(12.).reshape(4, 3)
        # Shape [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
        batch = jnp.array([example, -example])

        with self.assertRaisesRegex(
                ValueError,
                "Input sequence and indices must have the same batch size"):
            # Shape [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [1,2].
            indices = jnp.array([[1, 2]])
            _ = layers.gather(sequence=batch, indices=indices)
Example #3
    def test_gather_bad_indices(self):
        example = jnp.arange(12.).reshape(4, 3)
        # Shape [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
        batch = jnp.array([example, -example])

        with self.assertRaisesRegex(
                ValueError,
                "predictions per sequence cannot be greater than the maximum sequence"
        ):
            # Shape [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [2,5].
            indices = jnp.array([jnp.arange(5), jnp.arange(5)])
            _ = layers.gather(sequence=batch, indices=indices)
Example #4
    def test_gather(self):
        example = jnp.arange(12.).reshape(4, 3)
        # Shape [BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] = [2,4,3].
        batch = jnp.array([example, -example])

        # Shape [BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] = [2,2].
        indices = jnp.array([[0, 3], [1, 2]])

        outputs = layers.gather(sequence=batch, indices=indices)

        # Shape [BATCH_SIZE * MAX_PREDICTIONS_PER_SEQ, HIDDEN_DIM] = [4,3]
        self.assertEqual(outputs.shape, (4, 3))

        expected = jnp.array(
            [[0, 1, 2], [9, 10, 11], [-3, -4, -5], [-6, -7, -8]],
            dtype=jnp.float32)
        np.testing.assert_allclose(outputs, expected, atol=1e-12)
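For reference, the behaviour these tests pin down can be reproduced in a few lines of JAX. The sketch below is consistent with the shapes and error messages exercised above, but it is an assumption about how layers.gather could be written, not necessarily the actual implementation.

import jax.numpy as jnp


def gather(sequence, indices):
    """Gathers hidden vectors at per-example positions (minimal sketch).

    Args:
      sequence: <float>[BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM] hidden states.
      indices: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] positions to gather.

    Returns:
      <float>[BATCH_SIZE * MAX_PREDICTIONS_PER_SEQ, HIDDEN_DIM] gathered rows.
    """
    batch_size, seq_length, hidden_dim = sequence.shape
    if indices.shape[0] != batch_size:
        raise ValueError(
            "Input sequence and indices must have the same batch size: "
            f"{batch_size} vs {indices.shape[0]}")
    if indices.shape[1] > seq_length:
        raise ValueError(
            "The number of predictions per sequence cannot be greater than the "
            f"maximum sequence length: {indices.shape[1]} vs {seq_length}")

    # Flatten the batch and sequence dimensions, then offset each example's
    # indices so they address the correct rows of the flattened array.
    flat_sequence = sequence.reshape(-1, hidden_dim)
    flat_indices = indices + jnp.arange(batch_size)[:, None] * seq_length
    return jnp.take(flat_sequence, flat_indices.reshape(-1), axis=0)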