Beispiel #1
0
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              masked_lm_positions=None,
              masked_lm_labels=None,
              masked_lm_weights=None,
              next_sentence_labels=None,
              *,
              config,
              deterministic=False):
        """Runs the shared BERT encoder followed by the pre-training heads.

        What is returned depends on which optional inputs are supplied:

        * ``masked_lm_positions`` is None: the encoder outputs
          ``(sequence_output, pooled_output)``; no heads are built.
        * positions given, but ``masked_lm_labels`` or
          ``next_sentence_labels`` is None: the head logits
          ``(masked_lm_logits, next_sentence_logits)``.
        * positions and both label arrays given: the result of
          ``self._compute_metrics`` on the logits, labels and weights.

        Args:
          input_ids: forwarded unchanged to the encoder.
          input_mask: forwarded unchanged to the encoder.
          type_ids: forwarded unchanged to the encoder.
          masked_lm_positions: indices gathered from the sequence output to
            feed the masked-LM head; None skips both heads.
          masked_lm_labels: masked-LM targets, only used for metrics.
          masked_lm_weights: forwarded to ``_compute_metrics`` as-is; this
            method does not check it for None.
          next_sentence_labels: next-sentence targets, only used for metrics.
          config: model config; ``config.hidden_size`` sizes the transform
            layer, and the config also selects kernel init and activation.
          deterministic: forwarded only to the encoder call.
        """
        encoder = BertModel.shared(config=config, name='bert')
        sequence_output, pooled_output = encoder(
            input_ids, input_mask, type_ids, deterministic=deterministic)
        if masked_lm_positions is None:
            # Nothing to predict: hand back the raw encoder outputs.
            return sequence_output, pooled_output

        # --- Masked-LM head ------------------------------------------------
        # Gather hidden states at the masked positions, then run
        # dense -> activation -> layer norm before the output projection.
        # NOTE: statement order here mirrors submodule creation order; keep it.
        hidden = GatherIndexes(sequence_output, masked_lm_positions)
        hidden = nn.Dense(hidden,
                          config.hidden_size,
                          kernel_init=get_kernel_init(config),
                          name='predictions_transform_dense')
        activation = get_hidden_activation(config)
        hidden = activation(hidden)
        hidden = nn.LayerNorm(hidden,
                              epsilon=LAYER_NORM_EPSILON,
                              name='predictions_transform_layernorm')
        # The output kernel is the encoder's embedding table (tied weights).
        tied_kernel = encoder.get_embedding_table()
        masked_lm_logits = layers.OutputProjection(
            hidden, kernel=tied_kernel, name='predictions_output')

        # --- Next-sentence head --------------------------------------------
        # Two-output projection on the pooled output.
        next_sentence_logits = layers.OutputProjection(
            pooled_output,
            n_out=2,
            kernel_init=get_kernel_init(config),
            name='classification')

        have_all_labels = (masked_lm_labels is not None
                           and next_sentence_labels is not None)
        if not have_all_labels:
            return masked_lm_logits, next_sentence_logits
        return self._compute_metrics(masked_lm_logits,
                                     next_sentence_logits,
                                     masked_lm_labels,
                                     masked_lm_weights,
                                     next_sentence_labels)