Example #1
    def __call__(self,
                 input_ids,
                 input_mask,
                 type_ids,
                 masked_lm_positions,
                 masked_lm_labels,
                 masked_lm_weights,
                 next_sentence_labels,
                 deterministic=False):
        """Applies pre-training model on inputs.

        Args:
          input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
          input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
            inputs from padding. Only used by BERT.
          type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
            different types.
          masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] indices
            indicating which inputs are masked.
          masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true labels
            for masked inputs.
          masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] relative
            weighting for masked inputs.
          next_sentence_labels: <int>[BATCH_SIZE, 1] Labels for next sentence
            prediction task.
          deterministic: if True, dropout is not applied (e.g. during
            evaluation); if False, dropout is applied.

        Returns:
          Loss and metrics for given inputs.
        """
        sequence_output, pooled_output = EncoderModel(
            self.config, random_seed=self.random_seed,
            name="encoder")(input_ids,
                            input_mask,
                            type_ids,
                            deterministic=deterministic)

        masked_lm_output = layers.gather(sequence_output, masked_lm_positions)
        masked_lm_output = nn.Dense(self.config.d_emb,
                                    kernel_init=default_kernel_init,
                                    name="predictions_dense")(masked_lm_output)
        masked_lm_output = nn.gelu(masked_lm_output)
        masked_lm_output = nn.LayerNorm(
            epsilon=LAYER_NORM_EPSILON,
            name="predictions_layer_norm")(masked_lm_output)
        masked_lm_logits = layers.OutputProjection(
            kernel=self._get_embedding_table(),
            name="predictions_output")(masked_lm_output)

        next_sentence_logits = layers.OutputProjection(
            n_out=2, kernel_init=default_kernel_init,
            name="classification")(pooled_output)

        return _compute_pretraining_metrics(masked_lm_logits,
                                            next_sentence_logits,
                                            masked_lm_labels,
                                            masked_lm_weights,
                                            next_sentence_labels)
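The helper `_compute_pretraining_metrics` is not defined in this snippet; only its call signature is visible above. As a rough sketch of what such a helper typically computes for BERT-style pretraining (the body below is an assumption, not the source's implementation): a cross-entropy over the masked positions, weighted by `masked_lm_weights` so padded prediction slots do not contribute, plus a two-class cross-entropy for next-sentence prediction.

import jax
import jax.numpy as jnp


def _compute_pretraining_metrics(masked_lm_logits, next_sentence_logits,
                                 masked_lm_labels, masked_lm_weights,
                                 next_sentence_labels):
    """Hypothetical sketch; only the signature comes from the call above."""
    vocab_size = masked_lm_logits.shape[-1]
    # Masked LM: cross-entropy over the vocabulary at each masked position,
    # weighted so that padded prediction slots (weight 0.0) contribute nothing.
    mlm_log_probs = jax.nn.log_softmax(masked_lm_logits).reshape(-1, vocab_size)
    mlm_labels = masked_lm_labels.reshape(-1)
    weights = masked_lm_weights.reshape(-1)
    per_position_loss = -jnp.take_along_axis(
        mlm_log_probs, mlm_labels[:, None], axis=-1)[:, 0]
    masked_lm_loss = jnp.sum(per_position_loss * weights) / (
        jnp.sum(weights) + 1e-5)

    # Next-sentence prediction: standard two-class cross-entropy.
    nsp_log_probs = jax.nn.log_softmax(next_sentence_logits)
    nsp_labels = next_sentence_labels.reshape(-1)
    next_sentence_loss = -jnp.mean(
        jnp.take_along_axis(nsp_log_probs, nsp_labels[:, None], axis=-1))

    return {
        "loss": masked_lm_loss + next_sentence_loss,
        "masked_lm_loss": masked_lm_loss,
        "next_sentence_loss": next_sentence_loss,
    }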
Example #2
    def test_kernel_output_projection(self):
        batch_size = 2
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        rng, kernel_rng = jax.random.split(rng)
        kernel = jax.random.uniform(kernel_rng, (NUM_CLASSES, hidden_dim),
                                    minval=-1.0,
                                    maxval=1.0)
        kernel_projection = layers.OutputProjection(kernel=jnp.asarray(kernel),
                                                    name="predictions_output")
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, kernel_projection,
                                      init_batch)["params"]

        expected_keys = {"output_bias"}
        self.assertEqual(params.keys(), expected_keys)

        rng, input_rng = jax.random.split(rng)
        inputs = jax.random.randint(input_rng,
                                    (batch_size, max_seq_length, hidden_dim),
                                    minval=0,
                                    maxval=10)
        outputs = kernel_projection.apply({"params": params}, inputs=inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, NUM_CLASSES))
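This test (and the next) relies on an `init_layer_variables` helper whose definition is not shown. Under the usual Flax linen pattern it would plausibly be a thin wrapper like the following (an assumption, inferred from the call sites):

def init_layer_variables(rng, module, init_batch):
    """Hypothetical helper: initialize a Flax module from an example batch."""
    # module.init returns variables shaped like {"params": {...}}, which the
    # tests then index with ["params"].
    return module.init(rng, **init_batch)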
Example #3
    def test_classification_output_projection(self):
        batch_size = 2
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        classification_projection = layers.OutputProjection(
            n_out=NUM_CLASSES, name="classification")
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, classification_projection,
                                      init_batch)["params"]

        expected_keys = {"output_bias", "output_kernel"}
        self.assertEqual(params.keys(), expected_keys)

        rng, input_rng = jax.random.split(rng)
        inputs = jax.random.randint(input_rng,
                                    (batch_size, max_seq_length, hidden_dim),
                                    minval=0,
                                    maxval=10)
        outputs = classification_projection.apply({"params": params},
                                                  inputs=inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, NUM_CLASSES))
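Taken together, Examples 2 and 3 pin down the contract of `layers.OutputProjection`: with a fixed `kernel` it creates only an `output_bias` parameter, while with `n_out` it also learns an `output_kernel`. A minimal Flax module consistent with both tests might look like this (a sketch under those assumptions, not the library's actual implementation):

from typing import Any, Callable, Optional

import flax.linen as nn
import jax.numpy as jnp


class OutputProjection(nn.Module):
    """Sketch matching the parameter names asserted by the two tests."""
    kernel: Optional[jnp.ndarray] = None  # fixed kernel, e.g. tied embeddings
    n_out: Optional[int] = None           # output size for a learned kernel
    kernel_init: Callable[..., Any] = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, inputs):
        if self.kernel is not None:
            # Fixed kernel: no "output_kernel" parameter is created.
            kernel = self.kernel
        else:
            # Learned kernel of shape [n_out, hidden_dim] -> "output_kernel".
            kernel = self.param("output_kernel", self.kernel_init,
                                (self.n_out, inputs.shape[-1]))
        bias = self.param("output_bias", nn.initializers.zeros,
                          (kernel.shape[0],))
        return jnp.matmul(inputs, kernel.T) + bias

Passing a precomputed kernel is what lets Example #1 tie the masked-LM output projection to the embedding table via `kernel=self._get_embedding_table()` without allocating a second vocabulary-sized matrix.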
Example #4
  def __call__(
      self,
      input_ids,
      input_mask,
      type_ids,
      labels=None,
      deterministic=False
  ):
    """Applies model for sequence classification.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      labels: True labels associated with inputs. Generally only required for
        training. Shape depends on task type:
        * Classification: <int>[BATCH_SIZE]
        * Regression: <float>[BATCH_SIZE]
      deterministic: if True, dropout is not applied (e.g. during evaluation);
        if False, dropout is applied.

    Returns:
      * If labels supplied (training mode): Model loss and metrics.
      * If no labels supplied (prediction / evaluation mode): Logits of shape
        <float>[BATCH_SIZE, n_classes].
    """
    _, pooled_output = EncoderModel(
        self.config, name="encoder")(
            input_ids, input_mask, type_ids, deterministic=deterministic)

    logits = layers.OutputProjection(
        n_out=self.n_classes,
        kernel_init=default_kernel_init,
        name="classification")(
            pooled_output)

    if labels is None:
      # Code path used during evaluation or prediction; metrics can be computed
      # from logits by the caller.
      return logits

    # Code path used during training.
    if self.config.dataset_name == "glue/stsb":  # Regression task
      loss = jnp.mean((logits[..., 0] - labels)**2)
      return {"loss": loss, "num_labels": labels.size}
    else:  # Classification task
      logits = nn.log_softmax(logits)
      loss = -jnp.mean(
          jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
      correct_predictions = jnp.sum(logits.argmax(-1) == labels)
      return {
          "loss": loss,
          "correct_predictions": correct_predictions,
          "num_labels": labels.size
      }
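The `onehot` helper used in the classification branch is also external to this snippet; a common JAX formulation (as found in the Flax examples, and assumed here) is:

import jax.numpy as jnp


def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
    """One-hot encode <int>[BATCH_SIZE] labels as <float>[BATCH_SIZE, num_classes]."""
    # Broadcast [BATCH_SIZE, 1] against [num_classes] to get a boolean grid,
    # then map True/False to on_value/off_value.
    x = labels[..., None] == jnp.arange(num_classes)
    return jnp.where(x, on_value, off_value).astype(jnp.float32)

With the default values this is equivalent to `jax.nn.one_hot(labels, num_classes)`, so the cross-entropy above reduces to picking out the log-probability of each true label.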