Example #1
    def _get_fake_data(self, inputs, mlm_logits):
        """Sample from the generator to create corrupted input."""
        inputs = pretrain_helpers.unmask(inputs)
        # Optionally forbid the generator from sampling the original (correct) token.
        disallow = tf.one_hot(
            inputs.masked_lm_ids,
            depth=self._bert_config.vocab_size,
            dtype=tf.float32) if self._config.disallow_correct else None
        # Sample replacement tokens from the temperature-scaled generator softmax;
        # no gradient flows back through the sampling step.
        sampled_tokens = tf.stop_gradient(
            pretrain_helpers.sample_from_softmax(mlm_logits /
                                                 self._config.temperature,
                                                 disallow=disallow))
        sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
        # Write the sampled ids back into the unmasked input at the masked positions.
        updated_input_ids, masked = pretrain_helpers.scatter_update(
            inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
        if self._config.electric_objective:
            labels = masked
        else:
            # A position is labeled "fake" only if the sampled token differs from the original.
            labels = masked * (1 - tf.cast(
                tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
        updated_inputs = pretrain_data.get_updated_inputs(
            inputs, input_ids=updated_input_ids)
        FakedData = collections.namedtuple(
            "FakedData", ["inputs", "is_fake_tokens", "sampled_tokens"])
        return FakedData(inputs=updated_inputs,
                         is_fake_tokens=labels,
                         sampled_tokens=sampled_tokens)
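
Note: both examples call pretrain_helpers.sample_from_softmax without showing its body. Below is a minimal, hypothetical sketch of such a helper, assuming Gumbel-max sampling over the logits and a one-hot disallow mask; only the call signature is taken from the code above, the rest is an assumption.

import tensorflow.compat.v1 as tf

def sample_from_softmax_sketch(logits, disallow=None):
    """Draw one-hot token samples from a softmax over `logits` (hypothetical helper)."""
    if disallow is not None:
        # Push disallowed tokens (e.g. the original ids) toward zero probability.
        logits -= 1000.0 * disallow
    # Gumbel-max trick: argmax of (logits + Gumbel noise) is a sample from the softmax.
    uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
    gumbel_noise = -tf.log(-tf.log(uniform))
    sampled_ids = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
    # One-hot output so callers can argmax it back to ids or mix it with log-probs.
    return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1], dtype=tf.float32)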
Example #2
def _get_fake_data(inputs, mlm_logits):
    """Sample from the generator to create corrupted input."""
    masked_lm_weights = inputs.masked_lm_weights
    inputs = pretrain_helpers.unmask(inputs)
    # This variant hard-codes temperature = 1.0 and never disallows the original token.
    disallow = None
    sampled_tokens = tf.stop_gradient(
        pretrain_helpers.sample_from_softmax(mlm_logits / 1.0,
                                             disallow=disallow))

    # sampled_tokens: [batch_size, n_pos, n_vocab]
    # mlm_logits: [batch_size, n_pos, n_vocab]
    sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
    print(sampled_tokens_fp32, "===sampled_tokens_fp32===")
    # mlm_logprobs: [batch_size, n_pos, n_vocab]
    mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
    # pseudo_logprob: [batch_size, n_pos], log-prob of the sampled token at each position
    pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
    pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
    # [batch_size]
    pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
    # [batch_size]
    # pseudo_logprob /= (1e-10+tf.reduce_sum(tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
    print("== _get_fake_data pseudo_logprob ==", pseudo_logprob)
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = pretrain_helpers.scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)

    labels = masked * (
        1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = pretrain_data.get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        "FakedData",
        ["inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens,
                     pseudo_logprob=pseudo_logprob)
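
The pseudo_logprob above is the per-example sum, over the weighted masked positions, of the generator's log-probability of the token it sampled. A self-contained toy check of that reduction, using numpy and greedy picks as a stand-in for the sampled one-hot tokens (toy values only, not tied to the model):

import numpy as np

np.random.seed(0)
# Toy shapes: batch_size=1, n_pos=3 masked positions, n_vocab=4.
mlm_logits = np.random.randn(1, 3, 4)
masked_lm_weights = np.array([[1.0, 1.0, 0.0]])  # third position is padding

log_probs = mlm_logits - np.log(np.exp(mlm_logits).sum(-1, keepdims=True))  # log_softmax
sampled_ids = log_probs.argmax(-1)          # greedy picks stand in for sampled tokens here
sampled_one_hot = np.eye(4)[sampled_ids]    # [1, 3, 4], like sampled_tokens_fp32

# Log-prob of the chosen token at each position, zeroed on padded positions,
# then summed over positions: one scalar per example, as in pseudo_logprob above.
per_position = (log_probs * sampled_one_hot).sum(-1) * masked_lm_weights
pseudo_logprob = per_position.sum(-1)
print(pseudo_logprob.shape)  # (1,)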
Example #3
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        # masked_inputs = pretrain_helpers.mask(
        #     config, pretrain_data.features_to_inputs(features), config.mask_prob)
        # tf.logging.error(f"features to inputs: {pretrain_data.features_to_inputs(features)}")
        # tf.logging.error(f"features: {features}")
        masked_inputs = pretrain_helpers.mask(
            config,
            pretrain_helpers.unmask(
                pretrain_data.features_to_inputs(features)), config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        if config.uniform_generator:
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif config.electra_objective and config.untied_generator:
            generator = self._build_transformer(
                masked_inputs,
                is_training,
                bert_config=get_generator_config(config, self._bert_config),
                embedding_size=(None if config.untied_generator_embeddings else
                                embedding_size),
                untied_embeddings=config.untied_generator_embeddings,
                name="generator")
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        else:
            generator = self._build_transformer(masked_inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * mlm_output.loss
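        # Combined objective so far: weighted generator (MLM) loss; the weighted
        # discriminator loss is added below when config.electra_objective is set.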

        # Discriminator
        disc_output = None
        if config.electra_objective:
            discriminator = self._build_transformer(
                fake_data.inputs,
                is_training,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens)
            self.total_loss += config.disc_weight * disc_output.loss

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        self.scaffold_fn = None
        if config.init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, config.init_checkpoint)
            if config.use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(config.init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                self.scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(config.init_checkpoint,
                                              assignment_map)

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective:
            eval_fn_inputs.update({
                "disc_loss": disc_output.per_example_loss,
                "disc_labels": disc_output.labels,
                "disc_probs": disc_output.probs,
                "disc_preds": disc_output.preds,
                "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1,
                                            output_type=tf.int32)
            })
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
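
self.eval_metrics is packaged as the (metric_fn, tensors) pair that TPUEstimator expects, alongside self.total_loss and self.scaffold_fn. A hypothetical sketch of how an eval-mode model_fn might hand these to a TPUEstimatorSpec follows; the class name PretrainingModel and the params key are assumptions, only the attribute names come from the code above.

import tensorflow.compat.v1 as tf

def eval_model_fn(features, labels, mode, params):
    # PretrainingModel and the "pretrain_config" key are assumed names for this sketch.
    config = params["pretrain_config"]
    model = PretrainingModel(config, features, is_training=False)
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=model.total_loss,
        eval_metrics=model.eval_metrics,  # the (metric_fn, eval_fn_values) pair built above
        scaffold_fn=model.scaffold_fn)    # restores from config.init_checkpoint on TPU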