Example #1
    def get_prediction_module(self, bert_model, features, is_training,
                              percent_done):
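        # Sequence-tagging head with a CRF output layer: gather the encoder
        # outputs at the labeled positions, project them to per-tag logits,
        # decode the best tag sequence with the learned transition matrix, and
        # use the negative CRF log-likelihood as the per-example loss.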
        n_classes = len(self._get_label_mapping())
        reprs = bert_model.get_sequence_output()
        reprs = pretrain_helpers.gather_positions(
            reprs, features[self.name + "_labeled_positions"])
        seq_lengths = tf.cast(
            tf.reduce_sum(features[self.name + "_labels_mask"], axis=1),
            tf.int32)
        logits = tf.layers.dense(reprs, n_classes)

        with tf.variable_scope("crf", reuse=tf.AUTO_REUSE):
            trans_val = tf.get_variable("transition",
                                        shape=[n_classes, n_classes],
                                        dtype=tf.float32)
        predict_ids, _ = crf.crf_decode(logits, trans_val, seq_lengths)
        actual_ids = features[self.name + "_labels"]
        log_likelihood, _ = crf.crf_log_likelihood(
            inputs=logits,
            tag_indices=actual_ids,
            sequence_lengths=seq_lengths,
            transition_params=trans_val)
        losses = -log_likelihood

        return losses, dict(
            loss=losses,
            logits=logits,
            predictions=predict_ids,
            labels=features[self.name + "_labels"],
            labels_mask=features[self.name + "_labels_mask"],
            labeled_positions=features[self.name + "_labeled_positions"],
            eid=features[self.name + "_eid"],
        )
Example #2
    def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
        """Masked language modeling softmax layer."""
        masked_lm_weights = inputs.masked_lm_weights
        with tf.variable_scope("generator_predictions"):
            if self._config.uniform_generator:
                logits = tf.zeros(self._bert_config.vocab_size)
                logits_tiled = tf.zeros(
                    modeling.get_shape_list(inputs.masked_lm_ids) +
                    [self._bert_config.vocab_size])
                logits_tiled += tf.reshape(
                    logits, [1, 1, self._bert_config.vocab_size])
                logits = logits_tiled
                # The uniform generator never computes embedding-tied logits;
                # reuse the uniform logits so the second return value below is
                # always defined.
                logits_embed = logits
            else:
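                # Standard MLM head: gather hidden states at the masked
                # positions, project to the embedding size, apply layer norm,
                # and tie the output projection to the embedding table (plus a
                # learned bias).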
                relevant_hidden = pretrain_helpers.gather_positions(
                    model.get_sequence_output(), inputs.masked_lm_positions)
                hidden = tf.layers.dense(
                    relevant_hidden,
                    units=modeling.get_shape_list(
                        model.get_embedding_table())[-1],
                    activation=modeling.get_activation(
                        self._bert_config.hidden_act),
                    kernel_initializer=modeling.create_initializer(
                        self._bert_config.initializer_range),
                )
                hidden = modeling.layer_norm(hidden)
                output_bias = tf.get_variable(
                    "output_bias",
                    shape=[self._bert_config.vocab_size],
                    initializer=tf.zeros_initializer())
                logits_embed = tf.matmul(hidden,
                                         model.get_embedding_table(),
                                         transpose_b=True)
                logits = tf.nn.bias_add(logits_embed, output_bias)

            oh_labels = tf.one_hot(inputs.masked_lm_ids,
                                   depth=self._bert_config.vocab_size,
                                   dtype=tf.float32)

            probs = tf.nn.softmax(logits)
            log_probs = tf.nn.log_softmax(logits)
            label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)
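            # label_log_probs is the negative log-probability of the gold token
            # at each masked position; the scalar loss below averages it using
            # masked_lm_weights, which zero out padded masked-position slots.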

            numerator = tf.reduce_sum(inputs.masked_lm_weights *
                                      label_log_probs)
            denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
            loss = numerator / denominator
            preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

            MLMOutput = collections.namedtuple(
                "MLMOutput",
                ["logits", "probs", "loss", "per_example_loss", "preds"])
            return MLMOutput(logits=logits,
                             probs=probs,
                             per_example_loss=label_log_probs,
                             loss=loss,
                             preds=preds), logits_embed
Example #3
 def get_prediction_module(self, bert_model, features, is_training,
                           percent_done):
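     # Per-token softmax classification head (no CRF): gather representations
     # at the labeled positions, compute cross-entropy against the gold tags,
     # zero out padded positions with the label mask, and sum over the sequence.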
     n_classes = len(self._get_label_mapping())
     reprs = bert_model.get_sequence_output()
     reprs = pretrain_helpers.gather_positions(
         reprs, features[self.name + "_labeled_positions"])
     logits = tf.layers.dense(reprs, n_classes)
     losses = tf.nn.softmax_cross_entropy_with_logits(
         labels=tf.one_hot(features[self.name + "_labels"], n_classes),
         logits=logits)
     losses *= features[self.name + "_labels_mask"]
     losses = tf.reduce_sum(losses, axis=-1)
     return losses, dict(
         loss=losses,
         logits=logits,
         predictions=tf.argmax(logits, axis=-1),
         labels=features[self.name + "_labels"],
         labels_mask=features[self.name + "_labels_mask"],
         eid=features[self.name + "_eid"],
     )
Example #4
 def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
     """Masked language modeling softmax layer."""
     with tf.variable_scope("generator_predictions"):
         if self._config.uniform_generator:
             logits = tf.zeros(self._bert_config.vocab_size)
             logits_tiled = tf.zeros(
                 modeling.get_shape_list(inputs.masked_lm_ids) +
                 [self._bert_config.vocab_size])
             logits_tiled += tf.reshape(
                 logits, [1, 1, self._bert_config.vocab_size])
             logits = logits_tiled
         else:
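              # Learned generator: gather hidden states at the masked positions
              # and compute vocabulary logits with the get_token_logits helper.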
             relevant_reprs = pretrain_helpers.gather_positions(
                 model.get_sequence_output(), inputs.masked_lm_positions)
             logits = get_token_logits(relevant_reprs,
                                       model.get_embedding_table(),
                                       self._bert_config)
         return get_softmax_output(logits, inputs.masked_lm_ids,
                                   inputs.masked_lm_weights,
                                   self._bert_config.vocab_size)
Example #5
  def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model):
    """Masked language modeling softmax layer."""
    masked_lm_weights = inputs.masked_lm_weights
    with tf.variable_scope("generator_predictions"):
      if (self._config.uniform_generator or self._config.identity_generator
          or self._config.heuristic_generator):
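        # All of these trivial-generator variants start from uniform (all-zero)
        # logits over the vocabulary; the identity and heuristic variants mix
        # in their own distributions further below.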
        logits = tf.zeros(self._bert_config.vocab_size)
        logits_tiled = tf.zeros(
            modeling.get_shape_list(inputs.masked_lm_ids) +
            [self._bert_config.vocab_size])
        logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size])
        logits = logits_tiled
      else:
        relevant_hidden = pretrain_helpers.gather_positions(
            model.get_sequence_output(), inputs.masked_lm_positions)
        hidden = tf.layers.dense(
            relevant_hidden,
            units=modeling.get_shape_list(model.get_embedding_table())[-1],
            activation=modeling.get_activation(self._bert_config.hidden_act),
            kernel_initializer=modeling.create_initializer(
                self._bert_config.initializer_range))
        hidden = modeling.layer_norm(hidden)
        output_bias = tf.get_variable(
            "output_bias",
            shape=[self._bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.matmul(hidden, model.get_embedding_table(),
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

      oh_labels = tf.one_hot(
          inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
          dtype=tf.float32)

      probs = tf.nn.softmax(logits)

      if self._config.identity_generator:
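          # Identity generator: build a distribution that concentrates nearly
          # all of its mass on the original token (logit 25.0 at the gold id,
          # 0 elsewhere) and mix it into the uniform distribution with a weight
          # that grows linearly over training.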
          identity_logits = tf.zeros(self._bert_config.vocab_size)
          identity_logits_tiled = tf.zeros(
              modeling.get_shape_list(inputs.masked_lm_ids) +
              [self._bert_config.vocab_size])
          masked_identity_weights = tf.one_hot(
              inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
              dtype=tf.float32)
          identity_logits_tiled += 25.0 * masked_identity_weights
          identity_logits_tiled += tf.reshape(
              identity_logits, [1, 1, self._bert_config.vocab_size])
          identity_logits = identity_logits_tiled
          identity_probs = tf.nn.softmax(identity_logits)

          identity_weight = (
              self.global_step /
              tf.cast(self._config.num_train_steps, tf.float32)
          ) * self._config.max_identity_weight
          probs = probs * (1 - identity_weight) + identity_probs * identity_weight
          logits = tf.math.log(probs)  # softmax(log(probs)) = probs
      elif self._config.heuristic_generator:
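          # Heuristic generator: place high logits on the synonyms of each
          # masked token (inputs.masked_synonym_ids), mask out the padding id at
          # vocabulary index 0, and mix the resulting distribution in on a
          # linear schedule.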
          synonym_logits = tf.zeros(self._bert_config.vocab_size)
          synonym_logits_tiled = tf.zeros(
              modeling.get_shape_list(inputs.masked_lm_ids) +
              [self._bert_config.vocab_size])
          masked_synonym_weights = tf.reduce_sum(
              tf.one_hot(inputs.masked_synonym_ids,
                         depth=self._bert_config.vocab_size, dtype=tf.float32),
              -2)
          padded_synonym_mask = tf.concat(
              [tf.zeros([1]), tf.ones([self._bert_config.vocab_size - 1])], 0)
          masked_synonym_weights *= tf.expand_dims(
              tf.expand_dims(padded_synonym_mask, 0), 0)
          synonym_logits_tiled += 25.0 * masked_synonym_weights
          synonym_logits_tiled += tf.reshape(
              synonym_logits, [1, 1, self._bert_config.vocab_size])
          synonym_logits = synonym_logits_tiled
          synonym_probs = tf.nn.softmax(synonym_logits)

          if self._config.synonym_scheduler_type == 'linear':
              synonym_weight = (
                  self.global_step /
                  tf.cast(self._config.num_train_steps, tf.float32)
              ) * self._config.max_synonym_weight
              probs = probs * (1 - synonym_weight) + synonym_probs * synonym_weight
              logits = tf.math.log(probs)  # softmax(log(probs)) = probs

      log_probs = tf.nn.log_softmax(logits)
      label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1)

      numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs)
      denominator = tf.reduce_sum(masked_lm_weights) + 1e-6
      loss = numerator / denominator
      preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32)

      MLMOutput = collections.namedtuple(
          "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"])
      return MLMOutput(
          logits=logits, probs=probs, per_example_loss=label_log_probs,
          loss=loss, preds=preds)
Example #6
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        unmasked_inputs = pretrain_data.features_to_inputs(features)
        masked_inputs = pretrain_helpers.mask(config, unmasked_inputs,
                                              config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        cloze_output = None
        if config.uniform_generator:
            # simple generator sampling fakes uniformly at random
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif ((config.electra_objective or config.electric_objective)
              and config.untied_generator):
            generator_config = get_generator_config(config, self._bert_config)
            if config.two_tower_generator:
                # two-tower cloze model generator used for electric
                generator = TwoTowerClozeTransformer(config, generator_config,
                                                     unmasked_inputs,
                                                     is_training,
                                                     embedding_size)
                cloze_output = self._get_cloze_outputs(unmasked_inputs,
                                                       generator)
                mlm_output = get_softmax_output(
                    pretrain_helpers.gather_positions(
                        cloze_output.logits,
                        masked_inputs.masked_lm_positions),
                    masked_inputs.masked_lm_ids,
                    masked_inputs.masked_lm_weights,
                    self._bert_config.vocab_size)
            else:
                # small masked language model generator
                generator = build_transformer(
                    config,
                    masked_inputs,
                    is_training,
                    generator_config,
                    embedding_size=(None if config.untied_generator_embeddings
                                    else embedding_size),
                    untied_embeddings=config.untied_generator_embeddings,
                    scope="generator")
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
        else:
            # full-sized masked language model generator if using BERT objective or if
            # the generator and discriminator have tied weights
            generator = build_transformer(config,
                                          masked_inputs,
                                          is_training,
                                          self._bert_config,
                                          embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
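        # The generator's MLM logits are used to sample replacement tokens for
        # the masked positions; is_fake_tokens marks where the sampled token
        # differs from the original input.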
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * (cloze_output.loss
                                               if config.two_tower_generator
                                               else mlm_output.loss)

        # Discriminator
        disc_output = None
        if config.electra_objective or config.electric_objective:
            discriminator = build_transformer(
                config,
                fake_data.inputs,
                is_training,
                self._bert_config,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens,
                cloze_output)
            self.total_loss += config.disc_weight * disc_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective or config.electric_objective:
            eval_fn_inputs.update({
                "disc_loss": disc_output.per_example_loss,
                "disc_labels": disc_output.labels,
                "disc_probs": disc_output.probs,
                "disc_preds": disc_output.preds,
                "sampled_tokids": tf.argmax(
                    fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective or config.electric_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)