Example #1
    def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
                 is_training, features, num_train_steps):
        # Create a shared transformer encoder
        bert_config = training_utils.get_bert_config(config)
        self.bert_config = bert_config
        if config.debug:
            bert_config.num_hidden_layers = 3
            bert_config.hidden_size = 144
            bert_config.intermediate_size = 144 * 4
            bert_config.num_attention_heads = 4
        assert config.max_seq_length <= bert_config.max_position_embeddings
        bert_model = modeling.BertModel(bert_config=bert_config,
                                        is_training=is_training,
                                        input_ids=features["input_ids"],
                                        input_mask=features["input_mask"],
                                        token_type_ids=features["segment_ids"],
                                        use_one_hot_embeddings=config.use_tpu,
                                        embedding_size=config.embedding_size)
        percent_done = (
            tf.cast(tf.train.get_or_create_global_step(), tf.float32) /
            tf.cast(num_train_steps, tf.float32))

        # Add specific tasks
        self.outputs = {"task_id": features["task_id"]}
        losses = []
        for task in tasks:
            with tf.variable_scope("task_specific/" + task.name):
                task_losses, task_outputs = task.get_prediction_module(
                    bert_model, features, is_training, percent_done)
                losses.append(task_losses)
                self.outputs[task.name] = task_outputs
        self.loss = tf.reduce_sum(
            tf.stack(losses, -1) *
            tf.one_hot(features["task_id"], len(config.task_names)))
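As a reading aid, here is a minimal NumPy sketch (hypothetical shapes and values, not from the original code) of the final loss computation: stacking each task's per-example losses and multiplying by a one-hot of task_id keeps only the loss of the task each example actually belongs to.

    import numpy as np

    batch_losses = np.array([[0.5, 2.0],   # example 0: loss under task 0 / task 1
                             [1.0, 3.0]])  # example 1, shape [batch, num_tasks]
    task_id = np.array([0, 1])             # the task each example came from
    one_hot = np.eye(2)[task_id]           # [[1., 0.], [0., 1.]]
    loss = np.sum(batch_losses * one_hot)  # 0.5 + 3.0 = 3.5
    print(loss)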
Example #2
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, ratio, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        # Load ratio
        with tf.variable_scope("rw_masking"):
            with tf.variable_scope("ratio"):
                self.ratios = tf.constant(ratio)
                action_prob = tf.nn.embedding_lookup(self.ratios,
                                                     inputs.input_ids)

        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
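A minimal TF1 sketch (hypothetical values, not from the original code) of the rw_masking/ratio lookup above: ratios holds one masking probability per vocabulary id, and tf.nn.embedding_lookup maps each input id to its per-position masking probability.

    import tensorflow as tf  # TF 1.x, as in the example

    ratios = tf.constant([0.0, 0.1, 0.3, 0.9])  # probability for vocab ids 0..3
    input_ids = tf.constant([[2, 1, 3]])        # [batch=1, seq_len=3]
    action_prob = tf.nn.embedding_lookup(ratios, input_ids)

    with tf.Session() as sess:
        print(sess.run(action_prob))            # [[0.3 0.1 0.9]]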
Example #3
    def __init__(self, config: configure_finetuning.FinetuningConfig, tasks,
                 is_training, features, num_train_steps):
        # Create a shared transformer encoder
        bert_config = training_utils.get_bert_config(config)
        self.bert_config = bert_config
        if config.debug:
            bert_config.num_hidden_layers = 3
            bert_config.hidden_size = 144
            bert_config.intermediate_size = 144 * 4
            bert_config.num_attention_heads = 4

        # multi-choice mrc
        if any(isinstance(x, qa_tasks.MQATask) for x in tasks):
            seq_len = config.max_seq_length
            assert seq_len <= bert_config.max_position_embeddings
            bs, total_len = modeling.get_shape_list(features["input_ids"],
                                                    expected_rank=2)
            to_shape = [
                bs * config.max_options_num * config.evidences_top_k, seq_len
            ]
            bert_model = modeling.BertModel(
                bert_config=bert_config,
                is_training=is_training,
                input_ids=tf.reshape(features["input_ids"], to_shape),
                input_mask=tf.reshape(features["input_mask"], to_shape),
                token_type_ids=tf.reshape(features["segment_ids"], to_shape),
                use_one_hot_embeddings=config.use_tpu,
                embedding_size=config.embedding_size)
        else:
            assert config.max_seq_length <= bert_config.max_position_embeddings
            bert_model = modeling.BertModel(
                bert_config=bert_config,
                is_training=is_training,
                input_ids=features["input_ids"],
                input_mask=features["input_mask"],
                token_type_ids=features["segment_ids"],
                use_one_hot_embeddings=config.use_tpu,
                embedding_size=config.embedding_size)
        percent_done = (
            tf.cast(tf.train.get_or_create_global_step(), tf.float32) /
            tf.cast(num_train_steps, tf.float32))

        # Add specific tasks
        self.outputs = {"task_id": features["task_id"]}
        losses = []
        for task in tasks:
            with tf.variable_scope("task_specific/" + task.name):
                task_losses, task_outputs = task.get_prediction_module(
                    bert_model, features, is_training, percent_done)
                losses.append(task_losses)
                self.outputs[task.name] = task_outputs
        self.loss = tf.reduce_sum(
            tf.stack(losses, -1) *
            tf.one_hot(features["task_id"], len(config.task_names)))
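A minimal NumPy sketch (hypothetical sizes, not from the original code) of the multi-choice reshape above: each example packs max_options_num * evidences_top_k candidate sequences, which are flattened into the batch dimension before the shared encoder runs.

    import numpy as np

    bs, options, top_k, seq_len = 2, 4, 3, 8
    input_ids = np.zeros((bs, options * top_k * seq_len), dtype=np.int32)
    to_shape = (bs * options * top_k, seq_len)
    print(input_ids.reshape(to_shape).shape)  # (24, 8)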
Example #4
    def bert_module_fn(is_training):
        """Spec function for a token embedding module."""

        input_ids = tf.placeholder(shape=[None, None],
                                   dtype=tf.int32,
                                   name="input_ids")
        input_mask = tf.placeholder(shape=[None, None],
                                    dtype=tf.int32,
                                    name="input_mask")
        token_type = tf.placeholder(shape=[None, None],
                                    dtype=tf.int32,
                                    name="segment_ids")

        bert_config = training_utils.get_bert_config(config)

        model = modeling.BertModel(bert_config=bert_config,
                                   is_training=is_training,
                                   input_ids=input_ids,
                                   input_mask=input_mask,
                                   token_type_ids=token_type,
                                   use_one_hot_embeddings=use_tpu,
                                   embedding_size=config.embedding_size)

        seq_output = model.sequence_output
        pool_output = model.pooled_output

        vocab_file = tf.constant(value=vocab_path,
                                 dtype=tf.string,
                                 name="vocab_file")
        lower_case = tf.constant(do_lower_case)

        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)

        input_map = {
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": token_type
        }

        output_map = {
            "pooled_output": pool_output,
            "sequence_output": seq_output
        }

        output_info_map = {
            "vocab_file": vocab_file,
            "do_lower_case": lower_case
        }

        hub.add_signature(name="tokens", inputs=input_map, outputs=output_map)
        hub.add_signature(name="tokenization_info",
                          inputs={},
                          outputs=output_info_map)
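A minimal sketch (TF1 hub API; the export step is assumed and not shown in the original) of how the "tokens" signature defined above could be packaged and called:

    import tensorflow as tf
    import tensorflow_hub as hub

    spec = hub.create_module_spec(
        bert_module_fn, tags_and_args=[(set(), dict(is_training=False))])
    bert = hub.Module(spec, trainable=False)
    outputs = bert(
        inputs=dict(input_ids=tf.placeholder(tf.int32, [None, None]),
                    input_mask=tf.placeholder(tf.int32, [None, None]),
                    segment_ids=tf.placeholder(tf.int32, [None, None])),
        signature="tokens",
        as_dict=True)
    pooled_output = outputs["pooled_output"]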
Example #5
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        masked_inputs = pretrain_helpers.mask(
            config, pretrain_data.features_to_inputs(features),
            config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        if config.uniform_generator:
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif config.electra_objective and config.untied_generator:
            generator = self._build_transformer(
                masked_inputs,
                is_training,
                bert_config=get_generator_config(config, self._bert_config),
                embedding_size=(None if config.untied_generator_embeddings else
                                embedding_size),
                untied_embeddings=config.untied_generator_embeddings,
                name="generator")
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        else:
            generator = self._build_transformer(masked_inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * mlm_output.loss

        # Discriminator
        disc_output = None
        if config.electra_objective:
            discriminator = self._build_transformer(
                fake_data.inputs,
                is_training,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens)
            self.total_loss += config.disc_weight * disc_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective:
            eval_fn_inputs.update({
                "disc_loss":
                disc_output.per_example_loss,
                "disc_labels":
                disc_output.labels,
                "disc_probs":
                disc_output.probs,
                "disc_preds":
                disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
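A plain-Python sketch (illustrative numbers; the gen_weight/disc_weight defaults are an assumption taken from the ELECTRA paper, not from this file) of how the total loss above combines the two objectives:

    gen_weight, disc_weight = 1.0, 50.0  # assumed ELECTRA defaults
    mlm_loss, disc_loss = 7.2, 0.05      # hypothetical per-step losses
    total_loss = gen_weight * mlm_loss + disc_weight * disc_loss
    print(total_loss)                    # 9.7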
Example #6
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        # if config.debug:
        #   self._bert_config.num_hidden_layers = 3
        #   self._bert_config.hidden_size = 144
        #   self._bert_config.intermediate_size = 144 * 4
        #   self._bert_config.num_attention_heads = 4

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        proposal_distribution = 1.0
        if config.masking_strategy == pretrain_helpers.ENTROPY_STRATEGY:
            old_model = self._build_transformer(inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            entropy_output = self._get_entropy_output(inputs, old_model)
            proposal_distribution = entropy_output.entropy

        masked_inputs = pretrain_helpers.mask(
            config,
            pretrain_data.features_to_inputs(features),
            config.mask_prob,
            proposal_distribution=proposal_distribution)

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
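A minimal NumPy sketch (hypothetical values, not from the original code) of the entropy-based proposal distribution above: positions where the old model is most uncertain receive proportionally more masking probability.

    import numpy as np

    entropy = np.array([[0.2, 1.5, 0.7, 2.1]])  # per-position predictive entropy
    proposal = entropy / entropy.sum(axis=-1, keepdims=True)
    print(proposal)  # masking concentrates on high-entropy positions 1 and 3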
Example #7
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        self._teacher_config = training_utils.get_teacher_config(config)

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        old_model = self._build_transformer(inputs,
                                            is_training,
                                            embedding_size=embedding_size)
        input_states = old_model.get_sequence_output()
        input_states = tf.stop_gradient(input_states)

        teacher_output = self._build_teacher(input_states,
                                             inputs,
                                             is_training,
                                             embedding_size=embedding_size)
        # Calculate the proposal distribution
        action_prob = teacher_output.action_probs  # pi(x_i)

        coin_toss = tf.random.uniform([])
        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)
        if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
            random_masked_input = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)
            B, L = modeling.get_shape_list(inputs.input_ids)
            N = config.max_predictions_per_seq
            strategy_prob = tf.random.uniform([B])
            strategy_prob = tf.expand_dims(
                tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
            l_strategy_prob = tf.tile(strategy_prob, [1, L])
            n_strategy_prob = tf.tile(strategy_prob, [1, N])
            mix_input_ids = (
                masked_inputs.input_ids * l_strategy_prob +
                random_masked_input.input_ids * (1 - l_strategy_prob))
            mix_masked_lm_positions = (
                masked_inputs.masked_lm_positions * n_strategy_prob +
                random_masked_input.masked_lm_positions * (1 - n_strategy_prob))
            mix_masked_lm_ids = (
                masked_inputs.masked_lm_ids * n_strategy_prob +
                random_masked_input.masked_lm_ids * (1 - n_strategy_prob))
            n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
            mix_masked_lm_weights = (
                masked_inputs.masked_lm_weights * n_strategy_prob +
                random_masked_input.masked_lm_weights * (1 - n_strategy_prob))
            mix_masked_inputs = pretrain_data.get_updated_inputs(
                inputs,
                input_ids=tf.stop_gradient(mix_input_ids),
                masked_lm_positions=mix_masked_lm_positions,
                masked_lm_ids=mix_masked_lm_ids,
                masked_lm_weights=mix_masked_lm_weights,
                tag_ids=inputs.tag_ids)
            masked_inputs = mix_masked_inputs

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Teacher reward is the -log p(x_S|x;B)
        reward = tf.stop_gradient(
            tf.reduce_mean(mlm_output.per_example_loss, 1))
        self._baseline = tf.reduce_mean(reward, -1)
        self._std = tf.math.reduce_std(reward, -1)

        # Calculate teacher loss
        def compute_teacher_loss(log_q, reward, baseline, std):
            advantage = tf.abs((reward - baseline) / std)
            advantage = tf.stop_gradient(advantage)
            log_q = tf.Print(log_q, [log_q], "log_q: ")
            teacher_loss = tf.reduce_mean(-log_q * advantage)
            return teacher_loss

        teacher_loss = tf.cond(
            coin_toss < 0.1, lambda: compute_teacher_loss(
                log_q, reward, self._baseline, self._std),
            lambda: tf.constant(0.0))
        self.total_loss = mlm_output.loss + teacher_loss
        self.teacher_loss = teacher_loss
        self.mlm_loss = mlm_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
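A minimal NumPy sketch (hypothetical numbers, not from the original code) of the REINFORCE-style objective in compute_teacher_loss above: the teacher's log-probability of the sampled mask is weighted by a variance-normalized advantage of the student's MLM loss, with the baseline and std treated as constants.

    import numpy as np

    reward = np.array([2.0, 3.0, 4.0])             # per-sequence MLM loss
    baseline, std = reward.mean(), reward.std()    # mirrors _baseline / _std
    advantage = np.abs((reward - baseline) / std)  # stop_gradient in the original
    log_q = np.array([-5.0, -4.0, -6.0])           # teacher log prob of each mask
    teacher_loss = np.mean(-log_q * advantage)
    print(teacher_loss)                            # ~4.49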
Example #8
    def __init__(self, config: PretrainingConfig, features, is_training,
                 init_checkpoint):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        compute_type = modeling.infer_dtype(config.use_fp16)
        custom_getter = modeling.get_custom_getter(compute_type)

        with tf.variable_scope(tf.get_variable_scope(),
                               custom_getter=custom_getter):
            # Mask the input
            masked_inputs = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)

            # Generator
            embedding_size = (self._bert_config.hidden_size
                              if config.embedding_size is None else
                              config.embedding_size)
            if config.uniform_generator:
                mlm_output = self._get_masked_lm_output(masked_inputs, None)
            elif config.electra_objective and config.untied_generator:
                generator = self._build_transformer(
                    name="generator",
                    inputs=masked_inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    bert_config=get_generator_config(config,
                                                     self._bert_config),
                    embedding_size=None
                    if config.untied_generator_embeddings else embedding_size,
                    untied_embeddings=config.untied_generator_embeddings)
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
            else:
                generator = self._build_transformer(
                    name="electra",
                    inputs=masked_inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    embedding_size=embedding_size)
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
            fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
            self.mlm_output = mlm_output
            self.total_loss = config.gen_weight * mlm_output.loss

            utils.log("Generator is built!")

            # Discriminator
            self.disc_output = None
            if config.electra_objective:
                discriminator = self._build_transformer(
                    name="electra",
                    inputs=fake_data.inputs,
                    is_training=is_training,
                    use_fp16=config.use_fp16,
                    embedding_size=embedding_size)
                utils.log("Discriminator is built!")
                self.disc_output = self._get_discriminator_output(
                    inputs=fake_data.inputs,
                    discriminator=discriminator,
                    labels=fake_data.is_fake_tokens)
                self.total_loss += config.disc_weight * self.disc_output.loss

        if init_checkpoint and hvd.rank() == 0:
            print("Loading checkpoint", init_checkpoint)
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars=tf.trainable_variables(),
                init_checkpoint=init_checkpoint,
                prefix="")
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective:
            eval_fn_inputs.update({
                "disc_loss":
                self.disc_output.per_example_loss,
                "disc_labels":
                self.disc_output.labels,
                "disc_probs":
                self.disc_output.probs,
                "disc_preds":
                self.disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = dict(zip(eval_fn_keys, args))
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])

            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
Example #9
def main():
    tf.set_random_seed(1234)
    np.random.seed(0)
    batch_size = 1
    tf_datatype = tf.int32
    np_datatype = np.int32
    iterations = 10

    features_ph = {}
    features_ph["input_ids"] = tf.placeholder(dtype=tf_datatype,
                                                shape=[batch_size, 128],
                                                name="input_ids")
    features_ph["input_mask"] = tf.placeholder(dtype=tf_datatype,
                                                shape=[batch_size, 128],
                                                name="input_mask")
    features_ph["token_type_ids"] = tf.placeholder(dtype=tf_datatype,
                                                    shape=[batch_size, 128],
                                                    name="token_type_ids")

    features_data = {}
    # np.random.rand yields floats in [0, 1), so casting them to int32 would
    # produce all zeros; randint gives genuinely random integer inputs (the
    # vocab bound of 30000 and the 0/1 ranges below are illustrative choices).
    features_data["input_ids"] = np.random.randint(
        0, 30000, size=(batch_size, 128)).astype(np_datatype)
    features_data["input_mask"] = np.random.randint(
        0, 2, size=(batch_size, 128)).astype(np_datatype)
    features_data["token_type_ids"] = np.random.randint(
        0, 2, size=(batch_size, 128)).astype(np_datatype)

    features_feed_dict = {
        features_ph[key]: features_data[key]
        for key in features_ph
    }

    finetuning_config = configure_finetuning.FinetuningConfig("ConvBert", "./")
    bert_config = training_utils.get_bert_config(finetuning_config)
    bert_model = modeling.BertModel(
        bert_config=bert_config,
        is_training=False,
        input_ids=features_ph["input_ids"],
        input_mask=features_ph["input_mask"],
        token_type_ids=features_ph["token_type_ids"])

    #outputs_names = "discriminator_predictions/Sigmoid:0,discriminator_predictions/truediv:0,discriminator_predictions/Cast_2:0,discriminator_predictions/truediv_1:0"
    graph_outputs = bert_model.get_sequence_output()
    outputs_names = graph_outputs.name
    print("graph output: ", graph_outputs)
    run_op_list = []
    outputs_names_with_port = outputs_names.split(",")
    outputs_names_without_port = [
        name.split(":")[0] for name in outputs_names_with_port
    ]
    run_op_list = list(outputs_names_without_port)
    inputs_names_with_port = [features_ph[key].name for key in features_ph]

    # define saver
    #saver = tf.train.Saver(var_list=tf.trainable_variables())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(iterations):
            sess.run(run_op_list, feed_dict=features_feed_dict)
        tf_time_sum = 0
        a = datetime.now()
        for i in range(iterations):
            tf_result = sess.run(run_op_list, feed_dict=features_feed_dict)
        b = datetime.now()
        tf_time_sum = (b - a).total_seconds()
        tf_time = "[INFO] TF  execution time: " + str(
            tf_time_sum * 1000 / iterations) + " ms"
        # tf_result = tf_result.flatten()

        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, outputs_names_without_port)
        # frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
        # save frozen model
        with open("ConvBert.pb", "wb") as ofile:
            ofile.write(frozen_graph.SerializeToString())

    # tf.reset_default_graph()
    # tf.import_graph_def(frozen_graph, name='')

    # #with tf.Session(config=config) as sess:
    # sess = tf.Session(config=config)
    # graph_def = tf_optimize(inputs_names_with_port, outputs_names_without_port,
    #                         sess.graph_def, True)

    # with open("ConvBert_optimized_model.pb", "wb") as ofile:
    #     ofile.write(graph_def.SerializeToString())

    onnx_model_file = "ConvBert.onnx"
    command = "python3 -m tf2onnx.convert --input ConvBert.pb --output %s --fold_const --opset 12 --verbose" % onnx_model_file
    command += " --inputs "
    for name in inputs_names_with_port:
        command += "%s," % name
    command = command[:-1] + " --outputs "
    for name in outputs_names_with_port:
        command += "%s," % name
    command = command[:-1]
    print(command)
    os.system(command)
    #exit(0)

    command = "trtexec - -onnx = ConvBert.onnx - -verbose"
    os.system(command)
    print(command)
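A minimal sketch (assumes ConvBert.pb was produced by the script above) of reloading the frozen graph to verify the export before the ONNX conversion:

    import tensorflow as tf  # TF 1.x

    graph_def = tf.GraphDef()
    with open("ConvBert.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    print([op.name for op in graph.get_operations()][:10])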
Example #10
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        # masked_inputs = pretrain_helpers.mask(
        #     config, pretrain_data.features_to_inputs(features), config.mask_prob)

        masked_inputs = Inputs(
            # input_ids: one 128-token sequence with [MASK] tokens (id 103)
            tf.constant([[
                101, 2151, 11385, 2052, 2031, 103, 4235, 15484, 2011, 2796,
                4153, 14731, 1999, 103, 2733, 2144, 1996, 4946, 5419, 1012,
                1523, 103, 2031, 2116, 103, 3001, 4082, 1999, 1996, 103,
                1010, 2021, 2498, 2001, 3856, 2039, 1010, 103, 4373, 5902,
                19219, 11961, 17357, 103, 1010, 2708, 1997, 3095, 1997, 2634,
                1521, 103, 1998, 23093, 2015, 1998, 19332, 8237, 3094, 1010,
                2409, 26665, 1012, 1523, 2009, 2003, 2825, 2008, 1996, 103,
                7217, 2015, 103, 7237, 2125, 2004, 2057, 5452, 2006, 2019,
                2004, 1011, 3223, 3978, 1012, 2061, 3383, 3905, 103, 2015,
                2020, 4082, 1010, 2029, 2089, 2025, 2031, 1996, 3223, 2846,
                2000, 11487, 1037, 3462, 2012, 2019, 7998, 1997, 103, 1010,
                2199, 2519, 1012, 1524, 102, 103, 2796, 4153, 2038, 2019,
                2779, 5995, 1997, 2062, 2084, 2260, 1010, 102
            ]],
                        dtype=tf.int32),
            # input_mask: all 128 positions are real (non-padding) tokens
            tf.constant([[1] * 128], dtype=tf.int32),
            # segment_ids: 115 positions of segment A, then 13 of segment B
            tf.constant([[0] * 115 + [1] * 13], dtype=tf.int32),
            # masked_lm_positions
            tf.constant([[
                37, 69, 5, 106, 13, 115, 51, 43, 72, 116, 21, 21, 88, 24, 13,
                29, 93, 69, 108
            ]],
                        dtype=tf.int32),
            # masked_lm_ids
            tf.constant([[
                1524, 2510, 2042, 7998, 1996, 1996, 1055, 4886, 2020, 2796,
                2057, 2057, 7217, 7217, 1996, 2181, 2029, 2510, 3486
            ]],
                        dtype=tf.int32),
            # masked_lm_weights
            tf.constant([[1.0] * 19], dtype=tf.float32),
        )

        self.masked_inputs = masked_inputs

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        if config.uniform_generator:
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif config.electra_objective and config.untied_generator:
            generator = self._build_transformer(
                masked_inputs,
                is_training,
                bert_config=get_generator_config(config, self._bert_config),
                embedding_size=(None if config.untied_generator_embeddings else
                                embedding_size),
                untied_embeddings=config.untied_generator_embeddings,
                name="generator",
            )
            mlm_output, self.relevant_hidden = self._get_masked_lm_output(
                masked_inputs, generator)
        else:
            generator = self._build_transformer(masked_inputs,
                                                is_training,
                                                embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)

        FakedData = collections.namedtuple(
            "FakedData", ["inputs", "is_fake_tokens"])  # , "sampled_tokens"])
        fake_data = FakedData(
            Inputs(
                tf.constant([[
                    101, 2151, 11385, 2052, 2031, 20464, 4235, 15484, 2011,
                    2796, 4153, 14731, 1999, 4952, 2733, 2144, 1996, 4946,
                    5419, 1012, 1523, 8045, 2031, 2116, 7367, 3001, 4082, 1999,
                    1996, 22366, 1010, 2021, 2498, 2001, 3856, 2039, 1010,
                    1422, 4373, 5902, 19219, 11961, 17357, 16374, 1010, 2708,
                    1997, 3095, 1997, 2634, 1521, 21904, 1998, 23093, 2015,
                    1998, 19332, 8237, 3094, 1010, 2409, 26665, 1012, 1523,
                    2009, 2003, 2825, 2008, 1996, 15778, 7217, 2015, 12767,
                    7237, 2125, 2004, 2057, 5452, 2006, 2019, 2004, 1011, 3223,
                    3978, 1012, 2061, 3383, 3905, 4852, 2015, 2020, 4082, 1010,
                    20229, 2089, 2025, 2031, 1996, 3223, 2846, 2000, 11487,
                    1037, 3462, 2012, 2019, 21157, 1997, 431, 1010, 2199, 2519,
                    1012, 1524, 102, 16353, 7069, 4153, 2038, 2019, 2779, 5995,
                    1997, 2062, 2084, 2260, 1010, 102
                ]],
                            dtype=tf.int32),
                tf.constant([[
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1
                ]],
                            dtype=tf.int32),
                tf.constant([[
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
                    1, 1, 1, 1, 1, 1, 1, 1
                ]],
                            dtype=tf.int32),
                tf.constant([[
                    37, 69, 5, 106, 13, 115, 51, 43, 72, 116, 21, 21, 88, 24,
                    13, 29, 93, 69, 108
                ]],
                            dtype=tf.int32),
                tf.constant([[
                    1524, 2510, 2042, 7998, 1996, 1996, 1055, 4886, 2020, 2796,
                    2057, 2057, 7217, 7217, 1996, 2181, 2029, 2510, 3486
                ]],
                            dtype=tf.int32),
                tf.constant([[
                    1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                    1., 1., 1., 1.
                ]])),
            tf.constant([[
                0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
                1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
                0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0
            ]],
                        dtype=tf.int32))
        print("FAKE DATA", fake_data)

        self.token_embeddings = tf.identity(generator.token_embeddings)
        self.all_embeddings = tf.identity(generator.all_embeddings)
        # self.attention_mask = tf.identity(generator.attention_mask)

        self.generator_embedding_output = tf.identity(
            generator.get_embedding_output())
        self.generator_all_encoder_layers = tf.identity(
            generator.get_all_encoder_layers())
        self.sequence_output = tf.identity(generator.get_sequence_output())
        self.pooled_output = tf.identity(generator.get_pooled_output())

        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * mlm_output.loss

        # Discriminator
        disc_output = None
        if config.electra_objective:
            discriminator = self._build_transformer(
                fake_data.inputs,
                is_training,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens)
            self.total_loss += config.disc_weight * disc_output.loss

            # `discriminator` only exists on the electra_objective branch
            # above; accessing it unconditionally would raise a NameError, so
            # keep these attribute assignments inside the branch.
            self.discriminator_embedding_output = tf.identity(
                discriminator.get_embedding_output())
            self.discriminator_all_encoder_layers = tf.identity(
                discriminator.get_all_encoder_layers())
        self.discriminator_output = disc_output
Example #11
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        if config.debug:
            self._bert_config.num_hidden_layers = 3
            self._bert_config.hidden_size = 144
            self._bert_config.intermediate_size = 144 * 4
            self._bert_config.num_attention_heads = 4

        # Mask the input
        unmasked_inputs = pretrain_data.features_to_inputs(features)
        masked_inputs = pretrain_helpers.mask(config, unmasked_inputs,
                                              config.mask_prob)

        # Generator
        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)
        cloze_output = None
        if config.uniform_generator:
            # simple generator sampling fakes uniformly at random
            mlm_output = self._get_masked_lm_output(masked_inputs, None)
        elif ((config.electra_objective or config.electric_objective)
              and config.untied_generator):
            generator_config = get_generator_config(config, self._bert_config)
            if config.two_tower_generator:
                # two-tower cloze model generator used for electric
                generator = TwoTowerClozeTransformer(config, generator_config,
                                                     unmasked_inputs,
                                                     is_training,
                                                     embedding_size)
                cloze_output = self._get_cloze_outputs(unmasked_inputs,
                                                       generator)
                mlm_output = get_softmax_output(
                    pretrain_helpers.gather_positions(
                        cloze_output.logits,
                        masked_inputs.masked_lm_positions),
                    masked_inputs.masked_lm_ids,
                    masked_inputs.masked_lm_weights,
                    self._bert_config.vocab_size)
            else:
                # small masked language model generator
                generator = build_transformer(
                    config,
                    masked_inputs,
                    is_training,
                    generator_config,
                    embedding_size=(None if config.untied_generator_embeddings
                                    else embedding_size),
                    untied_embeddings=config.untied_generator_embeddings,
                    scope="generator")
                mlm_output = self._get_masked_lm_output(
                    masked_inputs, generator)
        else:
            # full-sized masked language model generator if using BERT objective or if
            # the generator and discriminator have tied weights
            generator = build_transformer(config,
                                          masked_inputs,
                                          is_training,
                                          self._bert_config,
                                          embedding_size=embedding_size)
            mlm_output = self._get_masked_lm_output(masked_inputs, generator)
        fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
        self.mlm_output = mlm_output
        self.total_loss = config.gen_weight * (cloze_output.loss
                                               if config.two_tower_generator
                                               else mlm_output.loss)

        # Discriminator
        disc_output = None
        if config.electra_objective or config.electric_objective:
            discriminator = build_transformer(
                config,
                fake_data.inputs,
                is_training,
                self._bert_config,
                reuse=not config.untied_generator,
                embedding_size=embedding_size)
            disc_output = self._get_discriminator_output(
                fake_data.inputs, discriminator, fake_data.is_fake_tokens,
                cloze_output)
            self.total_loss += config.disc_weight * disc_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        if config.electra_objective or config.electric_objective:
            eval_fn_inputs.update({
                "disc_loss":
                disc_output.per_example_loss,
                "disc_labels":
                disc_output.labels,
                "disc_probs":
                disc_output.probs,
                "disc_preds":
                disc_output.preds,
                "sampled_tokids":
                tf.argmax(fake_data.sampled_tokens, -1, output_type=tf.int32)
            })
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def metric_fn(*args):
            """Computes the loss and accuracy of the model."""
            d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            metrics = dict()
            metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
                labels=tf.reshape(d["masked_lm_ids"], [-1]),
                predictions=tf.reshape(d["masked_lm_preds"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            metrics["masked_lm_loss"] = tf.metrics.mean(
                values=tf.reshape(d["mlm_loss"], [-1]),
                weights=tf.reshape(d["masked_lm_weights"], [-1]))
            if config.electra_objective or config.electric_objective:
                metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
                    labels=tf.reshape(d["masked_lm_ids"], [-1]),
                    predictions=tf.reshape(d["sampled_tokids"], [-1]),
                    weights=tf.reshape(d["masked_lm_weights"], [-1]))
                if config.disc_weight > 0:
                    metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
                    metrics["disc_auc"] = tf.metrics.auc(
                        d["disc_labels"] * d["input_mask"],
                        d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
                    metrics["disc_accuracy"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["input_mask"])
                    metrics["disc_precision"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_preds"] * d["input_mask"])
                    metrics["disc_recall"] = tf.metrics.accuracy(
                        labels=d["disc_labels"],
                        predictions=d["disc_preds"],
                        weights=d["disc_labels"] * d["input_mask"])
            return metrics

        self.eval_metrics = (metric_fn, eval_fn_values)
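A minimal NumPy sketch (hypothetical shapes, not from the original code) of the gather_positions step used for the two-tower generator above: cloze logits exist for every position, and only the masked positions are kept for the MLM softmax.

    import numpy as np

    logits = np.arange(24, dtype=np.float32).reshape(1, 4, 6)  # [batch, seq, vocab]
    masked_lm_positions = np.array([[1, 3]])                   # [batch, num_masks]
    gathered = np.take_along_axis(
        logits, masked_lm_positions[:, :, None], axis=1)       # -> [1, 2, 6]
    print(gathered.shape)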