Example #1
def main(_):
    mode = tf.estimator.ModeKeys.TRAIN
    use_one_hot_embeddings = FLAGS.use_tpu

    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    # model_fn = model_fn_builder(
    #     bert_config=bert_config,
    #     init_checkpoint=FLAGS.init_checkpoint,
    #     learning_rate=FLAGS.learning_rate,
    #     num_train_steps=FLAGS.num_train_steps,
    #     num_warmup_steps=FLAGS.num_warmup_steps,
    #     use_tpu=FLAGS.use_tpu,
    #     use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    # estimator = tf.contrib.tpu.TPUEstimator(
    #     use_tpu=FLAGS.use_tpu,
    #     model_fn=model_fn,
    #     config=run_config,
    #     train_batch_size=FLAGS.train_batch_size,
    #     eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        # Read one global batch large enough to be split across all GPU towers.
        n_gpus = 4
        batch_size = FLAGS.train_batch_size
        d = input_fn(input_files, FLAGS.train_batch_size * n_gpus,
                     FLAGS.max_seq_length, FLAGS.max_predictions_per_seq, True)
        features, iterator = parse_input_fn_result(d)
        # train_input_fn = input_fn_builder(
        #     input_files=input_files,
        #     max_seq_length=FLAGS.max_seq_length,
        #     max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        #     is_training=True)
        # estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

        input_ids_list = tf.split(features["input_ids"], n_gpus, axis=0)
        input_mask_list = tf.split(features["input_mask"], n_gpus, axis=0)
        segment_ids_list = tf.split(features["segment_ids"], n_gpus, axis=0)
        masked_lm_positions_list = tf.split(features["masked_lm_positions"],
                                            n_gpus,
                                            axis=0)
        masked_lm_ids_list = tf.split(features["masked_lm_ids"],
                                      n_gpus,
                                      axis=0)
        masked_lm_weights_list = tf.split(features["masked_lm_weights"],
                                          n_gpus,
                                          axis=0)
        next_sentence_labels_list = tf.split(features["next_sentence_labels"],
                                             n_gpus,
                                             axis=0)

        # multi-gpu train
        with tf.device('/cpu:0'):
            optimizer = optimization_gpu.create_optimizer(
                None, FLAGS.learning_rate, FLAGS.num_train_steps,
                FLAGS.num_warmup_steps, False)

            global_step = tf.train.get_or_create_global_step()
            # calculate the gradients on each GPU
            tower_grads = []
            models = []
            train_perplexity = tf.get_variable(
                'train_perplexity', [],
                initializer=tf.constant_initializer(0.0),
                trainable=False)
            for k in range(n_gpus):
                with tf.device('/gpu:%d' % k):
                    with tf.variable_scope('lm', reuse=k > 0):
                        # calculate the loss for one model replica and get
                        #   lstm states

                        input_ids = input_ids_list[k]
                        input_mask = input_mask_list[k]
                        segment_ids = segment_ids_list[k]
                        masked_lm_positions = masked_lm_positions_list[k]
                        masked_lm_ids = masked_lm_ids_list[k]
                        masked_lm_weights = masked_lm_weights_list[k]
                        next_sentence_labels = next_sentence_labels_list[k]

                        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

                        model = modeling.BertModel(
                            config=bert_config,
                            is_training=is_training,
                            input_ids=input_ids,
                            input_mask=input_mask,
                            token_type_ids=segment_ids,
                            use_one_hot_embeddings=use_one_hot_embeddings)

                        (masked_lm_loss, masked_lm_example_loss,
                         masked_lm_log_probs) = get_masked_lm_output(
                             bert_config, model.get_sequence_output(),
                             model.get_embedding_table(), masked_lm_positions,
                             masked_lm_ids, masked_lm_weights)

                        (next_sentence_loss, next_sentence_example_loss,
                         next_sentence_log_probs) = get_next_sentence_output(
                             bert_config, model.get_pooled_output(),
                             next_sentence_labels)

                        total_loss = masked_lm_loss + next_sentence_loss

                        loss = total_loss
                        models.append(model)
                        # get gradients
                        grads = optimizer.compute_gradients(
                            loss,
                            aggregation_method=tf.AggregationMethod.
                            EXPERIMENTAL_TREE,
                        )
                        tower_grads.append(grads)
                        # keep track of loss across all GPUs
                        train_perplexity += loss

            average_grads = average_gradients(tower_grads, None, None)
            average_grads, norm_summary_ops = clip_grads(
                average_grads, 10.0, True, global_step)
            train_perplexity = tf.exp(train_perplexity / n_gpus)
            train_op = optimizer.apply_gradients(average_grads,
                                                 global_step=global_step)
            init = tf.global_variables_initializer()
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            sess.run(iterator.initializer)
            loss_sum = 0.0  # renamed from `sum` to avoid shadowing the built-in
            count = 0
            t0 = time.time()
            while True:

                _, train_perplexity_ = sess.run([train_op, train_perplexity])

                loss_sum += train_perplexity_
                count += 1
                if count % 100 == 0:
                    print("------------")
                    # time.time() measures seconds, not milliseconds
                    print(time.time() - t0, " s")
                    t0 = time.time()
                    # average over the last 100 steps, then reset the window
                    print("loss ", loss_sum / 100)
                    loss_sum = 0.0

                if count % 10000 == 0:
                    checkpoint_path = os.path.join(FLAGS.output_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        # NOTE: `estimator` is only defined if the commented-out TPUEstimator
        # construction above is restored; this eval branch depends on it.
        result = estimator.evaluate(input_fn=eval_input_fn,
                                    steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #2
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, input_ids,
                                       input_mask, segment_ids, label_ids,
                                       num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization_gpu.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps)

            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     scaffold=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            # eval_metrics = (metric_fn,
            #                 [per_example_loss, label_ids, logits, is_real_example])
            eval_metrics = metric_fn(per_example_loss, label_ids, logits,
                                     is_real_example)

            # output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            #     mode=mode,
            #     loss=total_loss,
            #     eval_metrics=eval_metrics,
            #     scaffold_fn=scaffold_fn)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=eval_metrics,
                scaffold=scaffold_fn)
        else:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={"probabilities": probabilities},
                scaffold=scaffold_fn)
        return output_spec
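
Each of these `model_fn` closures is normally produced by a `model_fn_builder` factory that captures the hyper-parameters, matching the commented-out call in Example #1. A hedged sketch of that wrapper (the exact signature is an assumption; the body is whichever `model_fn` variant applies):

def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns a `model_fn` closure over the captured configuration."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        # ... one of the model_fn bodies shown in these examples, reading the
        # captured variables (bert_config, init_checkpoint, ...) ...
        raise NotImplementedError("fill in with a model_fn body from above")

    return model_fn

The returned closure is what gets handed to the (TPU)Estimator constructor, e.g. `tf.estimator.Estimator(model_fn=model_fn_builder(...), config=run_config)`.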
Example #3
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, False)
            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     scaffold=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
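
Example #1 and this `model_fn` both depend on an `input_fn_builder` that is not shown. A minimal sketch assuming pre-training TFRecords with the feature spec these model_fns read (field names are taken from the examples; shuffling and parallelism are simplified):

import tensorflow as tf

def input_fn_builder(input_files, max_seq_length, max_predictions_per_seq,
                     is_training):
    """Returns an `input_fn` over pre-training TFRecords for (TPU)Estimator."""

    name_to_features = {
        "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        "masked_lm_positions":
            tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
        "masked_lm_ids":
            tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
        "masked_lm_weights":
            tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
        "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
    }

    def _decode(record):
        example = tf.parse_single_example(record, name_to_features)
        # tf.Example stores int64; the models above expect int32.
        return {name: tf.to_int32(t) if t.dtype == tf.int64 else t
                for name, t in example.items()}

    def input_fn(params):
        batch_size = params["batch_size"]
        d = tf.data.TFRecordDataset(input_files)
        if is_training:
            d = d.repeat().shuffle(buffer_size=100)
        return d.map(_decode).batch(batch_size, drop_remainder=True)

    return input_fn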
Example #4
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_origin_ids = features["lm_input_ids"]
        input_target_ids = features["lm_target_ids"]
        lm_weights = get_lm_weights(
            input_target_ids, word2id,
            ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])

        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_type_ids = features["input_type_ids"]
        #masked_lm_weights = features["masked_lm_weights"]
        masked_lm_weights = get_lm_weights(masked_lm_ids, word2id,
                                           ['[PAD]', '[UNK]'])
        # next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            origin_input_ids=input_origin_ids,
            type_ids=masked_type_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
            LM=FLAGS.LM)

        (masked_lm_loss, masked_lm_type_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), model.get_type_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_type_ids,
             masked_lm_weights)

        (lm_loss, lm_example_loss,
         lm_log_probs) = get_lm_output(bert_config,
                                       model.get_origin_sequence_output(),
                                       model.get_embedding_table(),
                                       input_target_ids, lm_weights)

        # (next_sentence_loss, next_sentence_example_loss,
        #  next_sentence_log_probs) = get_next_sentence_output(
        #     bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + masked_lm_type_loss + lm_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        # Count and report the total number of trainable parameters.
        total_parameters = 0
        for variable in tf.trainable_variables():
            # get_shape() returns a TensorShape of tf.Dimension entries
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('total parameters: ', total_parameters)
        with open('parameters.txt', 'w') as f:
            f.write(str(total_parameters))

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization_gpu.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps)
            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     train_op=train_op,
                                                     scaffold=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights, lm_example_loss,
                          lm_log_probs, input_target_ids, lm_weights):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                lm_log_probs = tf.reshape(lm_log_probs,
                                          [-1, lm_log_probs.shape[-1]])
                lm_predictions = tf.argmax(lm_log_probs,
                                           axis=-1,
                                           output_type=tf.int32)
                lm_example_loss = tf.reshape(lm_example_loss, [-1])
                lm_target_ids = tf.reshape(input_target_ids, [-1])
                lm_weights = tf.reshape(lm_weights, [-1])
                lm_accuracy = tf.metrics.accuracy(labels=lm_target_ids,
                                                  predictions=lm_predictions,
                                                  weights=lm_weights)
                lm_mean_loss = tf.metrics.mean(values=lm_example_loss,
                                               weights=lm_weights)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "lm_accuracy": lm_accuracy,
                    "lm_loss": lm_mean_loss,
                }

            # Relabel [UNK] targets as -1 so they can never match a prediction
            # and therefore never count as correct in the accuracy metrics.
            wrong_mask_lm_label = tf.constant(value=-1,
                                              dtype=tf.int32,
                                              shape=masked_lm_ids.shape)
            unk_id = word2id['[UNK]']
            unk_tf = tf.constant(value=unk_id,
                                 dtype=tf.int32,
                                 shape=masked_lm_ids.shape)
            condition_mask_lm_tf = tf.equal(masked_lm_ids, unk_tf)
            new_mask_lm_labels = tf.where(condition_mask_lm_tf,
                                          wrong_mask_lm_label, masked_lm_ids)

            wrong_lm_label = tf.constant(value=-1,
                                         dtype=tf.int32,
                                         shape=input_target_ids.shape)
            unk_tf = tf.constant(value=unk_id,
                                 dtype=tf.int32,
                                 shape=input_target_ids.shape)
            condition_lm_tf = tf.equal(input_target_ids, unk_tf)
            new_lm_labels = tf.where(condition_lm_tf, wrong_lm_label,
                                     input_target_ids)

            eval_metrics = metric_fn(
                masked_lm_example_loss,
                masked_lm_log_probs,
                new_mask_lm_labels,
                masked_lm_weights,
                lm_example_loss,
                lm_log_probs,
                new_lm_labels,
                lm_weights,
            )

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=eval_metrics,
                scaffold=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            lm_log_probs = tf.reshape(lm_log_probs,
                                      [-1, lm_log_probs.shape[-1]])
            lm_predictions = tf.argmax(lm_log_probs,
                                       axis=-1,
                                       output_type=tf.int32)
            lm_target_ids = tf.reshape(input_target_ids, [-1])

            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                              axis=-1,
                                              output_type=tf.int32)

            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])

            # Note: these two metrics are computed but unused in PREDICT mode.
            lm_accuracy = tf.metrics.accuracy(labels=lm_target_ids,
                                              predictions=lm_predictions,
                                              weights=lm_weights)
            lm_mean_loss = tf.metrics.mean(values=lm_example_loss,
                                           weights=lm_weights)

            # lm predictions
            predictions = {
                "input_ids": tf.reshape(input_origin_ids, [-1]),
                "lm_pre": lm_predictions,
                "lm_tar": lm_target_ids,
            }

            # # id predictions
            # predictions = {
            #     "masked_pre": masked_lm_predictions,
            #     "masked_tar": masked_lm_ids,
            # }

            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     predictions=predictions)

        else:
            raise ValueError(
                "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))

        return output_spec
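
Example #4 also leans on a `get_lm_weights` helper that is not defined here. Judging from its call sites, it zeroes the loss weight wherever the target id is one of the listed special tokens; a minimal sketch under that assumption (`word2id` is the vocabulary mapping already used above):

import tensorflow as tf

def get_lm_weights(target_ids, word2id, special_tokens):
    """Float weights shaped like target_ids: 0.0 where the target is a
    special token, 1.0 elsewhere."""
    weights = tf.ones_like(target_ids, dtype=tf.float32)
    for token in special_tokens:
        is_special = tf.equal(target_ids, word2id[token])
        weights = tf.where(is_special, tf.zeros_like(weights), weights)
    return weights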