コード例 #1
0
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, ratio, is_training):
        """Build an MLM pretraining graph with per-token-id masking ratios.

        Every input token id looks up its own masking probability from the
        supplied ``ratio`` table, a subset of positions is sampled for
        masking, and a BERT-style transformer is trained to recover the
        masked tokens.

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            ratio: table of masking ratios indexed by vocab id (used as an
                embedding table below) — assumed to cover the full vocab;
                TODO confirm against the caller.
            is_training: whether the graph is built in training mode.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)

        # Fall back to the transformer hidden size when no explicit
        # embedding size is configured.
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        # Tokenizer is kept only for its vocab <-> id mappings.
        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        # Load ratio: embedding_lookup maps each token id to its masking
        # ratio, yielding a per-position masking probability.
        with tf.variable_scope("rw_masking"):
            with tf.variable_scope("ratio"):
                self.ratios = tf.constant(ratio)
                action_prob = tf.nn.embedding_lookup(self.ratios,
                                                     inputs.input_ids)

        # Sample the positions to mask; log_q (log-probability of the
        # sampled subset) is unused here — this model has no teacher loss.
        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Evaluation: metrics are built eagerly here (dict of tf.metrics
        # ops) rather than deferred via a metric_fn closure.
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
コード例 #2
0
def test_data_generator(features):
    """Create masked inputs and generator fake data for the given features.

    NOTE(review): relies on a module-level ``config`` object (not a
    parameter) for the masking probability — confirm ``config`` is defined
    at module scope in the enclosing file.

    Args:
        features: dict of input tensors accepted by
            ``pretrain_data.features_to_inputs``.

    Returns:
        Tuple ``(features, unmasked_inputs, fake_data, masked_inputs)``.
    """
    # Mask the input
    unmasked_inputs = pretrain_data.features_to_inputs(features)
    masked_inputs = pretrain_helpers.mask(config, unmasked_inputs,
                                          config.mask_prob)

    # Generator: predict the masked-out tokens (uniform generator — no
    # transformer is passed), then build "fake" data by replacing masked
    # positions with sampled predictions.
    mlm_output = _get_masked_lm_output(masked_inputs, None)
    fake_data = _get_fake_data(masked_inputs, mlm_output.logits)

    return features, unmasked_inputs, fake_data, masked_inputs
コード例 #3
0
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        """Build the ELECTRA pretraining graph: generator + discriminator.

        A generator (uniform, small untied, or full-size tied) fills in the
        masked positions; when ``config.electra_objective`` is set, a
        discriminator is additionally trained to detect replaced tokens.

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            is_training: whether the graph is built in training mode.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        # Shrink the network in debug mode for fast iteration.
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        masked_inputs = pretrain_helpers.mask(
            config, pretrain_data.features_to_inputs(features),
            config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        # Three generator variants:
        #   1. uniform: sample replacements uniformly at random (no network)
        #   2. untied ELECTRA: separate (usually smaller) generator network
        #   3. otherwise: full-sized generator sharing weights with the
        #      discriminator
        if config.uniform_generator:
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif config.electra_objective and config.untied_generator:
            generator = self._build_transformer(
                masked_inputs,
                is_training,
                bert_config=get_generator_config(config, self._bert_config),
                embedding_size=(None if config.untied_generator_embeddings else
                                embedding_size),
                untied_embeddings=config.untied_generator_embeddings,
                name="generator")
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        else:
            generator = self._build_transformer(masked_inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        # Replace masked positions with tokens sampled from the generator.
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * mlm_output.loss

        # Discriminator: binary per-token real-vs-replaced classification.
        disc_output = None
        if config.electra_objective:
            discriminator = self._build_transformer(
                fake_data.inputs,
                is_training,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens)
            self.total_loss += config.disc_weight * disc_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective:
            eval_fn_inputs.update({
                "disc_loss":
                disc_output.per_example_loss,
                "disc_labels":
                disc_output.labels,
                "disc_probs":
                disc_output.probs,
                "disc_preds":
                disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        # Keys and values are split so metric_fn can be called positionally
        # (TPU estimator-style eval_metrics contract).
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    # NOTE(review): precision/recall are computed with
                    # tf.metrics.accuracy restricted by weights (predicted-
                    # positive vs. actual-positive tokens), not with
                    # tf.metrics.precision/recall.
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
コード例 #4
0
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        """Build an MLM pretraining graph with optional entropy-based masking.

        Under ``ENTROPY_STRATEGY`` a first transformer pass scores each
        position's entropy, and those scores bias which tokens get masked;
        otherwise masking is sampled with a uniform proposal distribution.

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            is_training: whether the graph is built in training mode.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        # if config.debug:
        #   self._bert_config.num_hidden_layers = 3
        #   self._bert_config.hidden_size = 144
        #   self._bert_config.intermediate_size = 144 * 4
        #   self._bert_config.num_attention_heads = 4

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        # Uniform proposal by default; replaced by per-position entropy
        # scores when the entropy masking strategy is selected.
        proposal_distribution = 1.0
        if config.masking_strategy == pretrain_helpers.ENTROPY_STRATEGY:
            old_model = self._build_transformer(inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            entropy_output = self._get_entropy_output(inputs, old_model)
            proposal_distribution = entropy_output.entropy

        masked_inputs = pretrain_helpers.mask(
            config,
            pretrain_data.features_to_inputs(features),
            config.mask_prob,
            proposal_distribution=proposal_distribution)

        # BERT model (reuses variables from the entropy pass, if any).
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Evaluation: metrics are built eagerly here (dict of tf.metrics
        # ops) rather than deferred via a metric_fn closure.
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
コード例 #5
0
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        """Build an MLM pretraining graph with a learned masking teacher.

        A teacher network produces per-position masking probabilities from
        the (gradient-stopped) student representations; the student is
        trained on the resulting masked input, and the teacher is updated
        with a REINFORCE-style loss whose reward is the student's MLM loss.

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            is_training: whether the graph is built in training mode.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        self._teacher_config = training_utils.get_teacher_config(config)

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        # Tokenizer is kept only for its vocab <-> id mappings.
        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        old_model = self._build_transformer(inputs,
                                            is_training,
                                            embedding_size=embedding_size)
        input_states = old_model.get_sequence_output()
        # Teacher must not backprop into the student encoder.
        input_states = tf.stop_gradient(input_states)

        teacher_output = self._build_teacher(input_states,
                                             inputs,
                                             is_training,
                                             embedding_size=embedding_size)
        # calculate the proposal distribution

        action_prob = teacher_output.action_probs  # pi(x_i): per-position masking policy

        # coin_toss later gates whether the teacher loss is applied on this
        # step (see tf.cond below).
        coin_toss = tf.random.uniform([])
        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)
        if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
            # Mix teacher-driven and random masking: each example in the
            # batch uses one of the two strategies, chosen with prob 0.5.
            random_masked_input = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)
            B, L = modeling.get_shape_list(inputs.input_ids)
            N = config.max_predictions_per_seq
            # 0/1 per-example selector, broadcast to sequence length (L)
            # and prediction count (N).
            strategy_prob = tf.random.uniform([B])
            strategy_prob = tf.expand_dims(
                tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
            l_strategy_prob = tf.tile(strategy_prob, [1, L])
            n_strategy_prob = tf.tile(strategy_prob, [1, N])
            mix_input_ids = masked_inputs.input_ids * l_strategy_prob + random_masked_input.input_ids * (
                1 - l_strategy_prob)
            mix_masked_lm_positions = masked_inputs.masked_lm_positions * n_strategy_prob + random_masked_input.masked_lm_positions * (
                1 - n_strategy_prob)
            mix_masked_lm_ids = masked_inputs.masked_lm_ids * n_strategy_prob + random_masked_input.masked_lm_ids * (
                1 - n_strategy_prob)
            # Weights are float, so the selector is cast before mixing.
            n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
            mix_masked_lm_weights = masked_inputs.masked_lm_weights * n_strategy_prob + random_masked_input.masked_lm_weights * (
                1 - n_strategy_prob)
            mix_masked_inputs = pretrain_data.get_updated_inputs(
                inputs,
                input_ids=tf.stop_gradient(mix_input_ids),
                masked_lm_positions=mix_masked_lm_positions,
                masked_lm_ids=mix_masked_lm_ids,
                masked_lm_weights=mix_masked_lm_weights,
                tag_ids=inputs.tag_ids)
            masked_inputs = mix_masked_inputs

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Teacher reward is the -log p(x_S|x;B)
        reward = tf.stop_gradient(
            tf.reduce_mean(mlm_output.per_example_loss, 1))
        # Batch statistics used to normalize the advantage.
        self._baseline = tf.reduce_mean(reward, -1)
        self._std = tf.math.reduce_std(reward, -1)

        # Calculate teacher loss
        def compute_teacher_loss(log_q, reward, baseline, std):
            """REINFORCE-style loss: -log q weighted by normalized advantage."""
            advantage = tf.abs((reward - baseline) / std)
            advantage = tf.stop_gradient(advantage)
            # Debug print of the subset log-probability.
            log_q = tf.Print(log_q, [log_q], "log_q: ")
            teacher_loss = tf.reduce_mean(-log_q * advantage)
            return teacher_loss

        # Only apply the teacher loss on ~10% of steps.
        teacher_loss = tf.cond(
            coin_toss < 0.1, lambda: compute_teacher_loss(
                log_q, reward, self._baseline, self._std),
            lambda: tf.constant(0.0))
        self.total_loss = mlm_output.loss + teacher_loss
        self.teacher_loss = teacher_loss
        self.mlm_loss = mlm_output.loss

        # Evaluation: metrics are built eagerly here (dict of tf.metrics
        # ops) rather than deferred via a metric_fn closure.
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
コード例 #6
0
    def __init__(self, config: PretrainingConfig, features, is_training,
                 init_checkpoint):
        """Build the ELECTRA pretraining graph with fp16 and checkpoint init.

        Same generator/discriminator structure as standard ELECTRA, but the
        whole model is built under a custom getter for mixed-precision
        (``config.use_fp16``), and variables can be initialized from
        ``init_checkpoint`` (only on Horovod rank 0).

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            is_training: whether the graph is built in training mode.
            init_checkpoint: optional checkpoint path to warm-start from.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        # Shrink the network in debug mode for fast iteration.
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Custom getter casts variables to the compute dtype (fp16/fp32).
        compute_type = modeling.infer_dtype(config.use_fp16)
        custom_getter = modeling.get_custom_getter(compute_type)

        with tf.variable_scope(tf.get_variable_scope(),
                               custom_getter=custom_getter):
            # Mask the input
            masked_inputs = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)

            # Generator
            embedding_size = self._bert_config.hidden_size if config.embedding_size is None else config.embedding_size
            # Three generator variants: uniform sampling (no network),
            # separate untied generator, or full-sized generator sharing
            # the "electra" scope with the discriminator.
            if config.uniform_generator:
                mlm_output = self._get_masked_lm_output(masked_inputs, None)
            elif config.electra_objective and config.untied_generator:
                generator = self._build_transformer(
                    name="generator",
                    inputs=masked_inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    bert_config=get_generator_config(config,
                                                     self._bert_config),
                    embedding_size=None
                    if config.untied_generator_embeddings else embedding_size,
                    untied_embeddings=config.untied_generator_embeddings)
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
            else:
                generator = self._build_transformer(
                    name="electra",
                    inputs=masked_inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    embedding_size=embedding_size)
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
            # Replace masked positions with tokens sampled from the generator.
            fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
            self.mlm_output = mlm_output
            self.total_loss = config.gen_weight * mlm_output.loss

            utils.log("Generator is built!")

            # Discriminator: per-token real-vs-replaced classification.
            self.disc_output = None
            if config.electra_objective:
                discriminator = self._build_transformer(
                    name="electra",
                    inputs=fake_data.inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    embedding_size=embedding_size)
                utils.log("Discriminator is built!")
                self.disc_output = self._get_discriminator_output(
                    inputs=fake_data.inputs,
                    discriminator=discriminator,
                    labels=fake_data.is_fake_tokens)
                self.total_loss += config.disc_weight * self.disc_output.loss

        # Warm-start from a checkpoint on rank 0 only; other Horovod ranks
        # receive the weights via broadcast — TODO confirm the broadcast
        # happens elsewhere in the training setup.
        if init_checkpoint and hvd.rank() == 0:
            print("Loading checkpoint", init_checkpoint)
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars=tf.trainable_variables(),
                init_checkpoint=init_checkpoint,
                prefix="")
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective:
            eval_fn_inputs.update({
                "disc_loss":
                self.disc_output.per_example_loss,
                "disc_labels":
                self.disc_output.labels,
                "disc_probs":
                self.disc_output.probs,
                "disc_preds":
                self.disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        # Keys and values are split so metric_fn can be called positionally
        # (TPU estimator-style eval_metrics contract).
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = dict(zip(eval_fn_keys, args))
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    # NOTE(review): precision/recall are computed with
                    # tf.metrics.accuracy restricted by weights, not with
                    # tf.metrics.precision/recall.
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])

            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
コード例 #7
0
ファイル: run_pretraining.py プロジェクト: zihua/electra
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        """Build the ELECTRA/Electric pretraining graph.

        Supports a uniform generator, a small untied MLM generator, a
        two-tower cloze generator (for the Electric objective), or a
        full-sized tied generator; when an ELECTRA/Electric objective is
        enabled, a discriminator is trained to detect replaced tokens.

        Args:
            config: pretraining hyperparameters.
            features: dict of input tensors accepted by
                ``pretrain_data.features_to_inputs``.
            is_training: whether the graph is built in training mode.
        """
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        # Shrink the network in debug mode for fast iteration.
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        unmasked_inputs = pretrain_data.features_to_inputs(features)
        masked_inputs = pretrain_helpers.mask(config, unmasked_inputs,
                                              config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        cloze_output = None
        if config.uniform_generator:
            # simple generator sampling fakes uniformly at random
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif ((config.electra_objective or config.electric_objective)
              and config.untied_generator):
            generator_config = get_generator_config(config, self._bert_config)
            if config.two_tower_generator:
                # two-tower cloze model generator used for electric;
                # it scores every position, so its logits are gathered at
                # the masked positions to form the MLM output.
                generator = TwoTowerClozeTransformer(config, generator_config,
                                                     unmasked_inputs,
                                                     is_training,
                                                     embedding_size)
                cloze_output = self._get_cloze_outputs(unmasked_inputs,
                                                       generator)
                mlm_output = get_softmax_output(
                    pretrain_helpers.gather_positions(
                        cloze_output.logits,
                        masked_inputs.masked_lm_positions),
                    masked_inputs.masked_lm_ids,
                    masked_inputs.masked_lm_weights,
                    self._bert_config.vocab_size)
            else:
                # small masked language model generator
                generator = build_transformer(
                    config,
                    masked_inputs,
                    is_training,
                    generator_config,
                    embedding_size=(None if config.untied_generator_embeddings
                                    else embedding_size),
                    untied_embeddings=config.untied_generator_embeddings,
                    scope="generator")
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
        else:
            # full-sized masked language model generator if using BERT objective or if
            # the generator and discriminator have tied weights
            generator = build_transformer(config,
                                          masked_inputs,
                                          is_training,
                                          self._bert_config,
                                          embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        # Replace masked positions with tokens sampled from the generator.
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        # Two-tower generator trains on the cloze loss; otherwise use the
        # standard MLM loss.
        self.total_loss = config.gen_weight * (cloze_output.loss
                                               if config.two_tower_generator
                                               else mlm_output.loss)

        # Discriminator: per-token real-vs-replaced classification.
        disc_output = None
        if config.electra_objective or config.electric_objective:
            discriminator = build_transformer(
                config,
                fake_data.inputs,
                is_training,
                self._bert_config,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens,
                cloze_output)
            self.total_loss += config.disc_weight * disc_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective or config.electric_objective:
            eval_fn_inputs.update({
                "disc_loss":
                disc_output.per_example_loss,
                "disc_labels":
                disc_output.labels,
                "disc_probs":
                disc_output.probs,
                "disc_preds":
                disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        # Keys and values are split so metric_fn can be called positionally
        # (TPU estimator-style eval_metrics contract).
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective or config.electric_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    # NOTE(review): precision/recall are computed with
                    # tf.metrics.accuracy restricted by weights, not with
                    # tf.metrics.precision/recall.
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)