Example #1
    def model_fn(features, labels, mode, params):
        """doc."""
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        assert is_training

        #### Retrieve `mems` (Transformer-XL memory) from the cache
        mems = {}
        if FLAGS.mem_len > 0:
            mems["mems"] = cache_fn()
        #### Get loss from inputs
        total_loss, new_mems, monitor_dict = function_builder.get_loss(
            FLAGS, features, labels, mems, is_training)

        #### Turn `new_mems` into `new_cache`
        new_cache = []
        if FLAGS.mem_len > 0:
            new_cache += new_mems["mems"]

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info("#params: {}".format(num_params))

        #### Configuring the optimizer
        train_op, learning_rate, gnorm = model_utils.get_train_op(
            FLAGS, total_loss)
        monitor_dict["lr"] = learning_rate
        monitor_dict["gnorm"] = gnorm

        #### Customized initial checkpoint
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True)

        #### Creating host calls
        host_call = function_builder.construct_scalar_host_call(
            monitor_dict=monitor_dict,
            model_dir=FLAGS.model_dir,
            prefix="train/",
            reduce_fn=tf.reduce_mean)

        #### Constructing training TPUEstimatorSpec with new cache.
        if FLAGS.use_tpu:
            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)
        else:
            train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=total_loss,
                                                    train_op=train_op)

        train_spec.cache = new_cache

        return train_spec
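
A minimal sketch (not from the original source) of how a model_fn like Example #1's is wired into a TF 1.x estimator. Note that `cache_fn()`, `params["cache"]`, and the `train_spec.cache` attribute are consumed by XLNet's patched TPUEstimator; the stock `tf.contrib.tpu.TPUEstimator` does not use them. The flag and input-function names below are assumptions.

    import tensorflow as tf

    run_config = tf.contrib.tpu.RunConfig(
        model_dir=FLAGS.model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations))

    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=FLAGS.use_tpu,
        train_batch_size=FLAGS.train_batch_size)

    # train_input_fn(params) is expected to return a tf.data.Dataset whose
    # batch size comes from params["batch_size"], injected by TPUEstimator.
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)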
Example #2
    def model_fn(features, labels, mode, params):
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Get loss from inputs
        if FLAGS.is_regression:
            (total_loss, per_example_loss,
             logits) = function_builder.get_regression_loss(
                 FLAGS, features, is_training)
        else:
            (total_loss, per_example_loss,
             logits) = function_builder.get_classification_loss(
                 FLAGS, features, n_class, is_training)

        # Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info('#params: {}'.format(num_params))

        # load pretrained models
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        # Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            assert FLAGS.num_hosts == 1

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                eval_input_dict = {
                    'labels': label_ids,
                    'predictions': predictions,
                    'weights': is_real_example
                }
                accuracy = tf.metrics.accuracy(**eval_input_dict)

                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {'eval_accuracy': accuracy, 'eval_loss': loss}

            def regression_metric_fn(per_example_loss, label_ids, logits,
                                     is_real_example):
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                pearsonr = tf.contrib.metrics.streaming_pearson_correlation(
                    logits, label_ids, weights=is_real_example)
                return {'eval_loss': loss, 'eval_pearsonr': pearsonr}

            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)

            # Constructing evaluation TPUEstimatorSpec.
            label_ids = tf.reshape(features['label_ids'], [-1])

            if FLAGS.is_regression:
                metric_fn = regression_metric_fn
            metric_args = [
                per_example_loss, label_ids, logits, is_real_example
            ]

            if FLAGS.use_tpu:
                eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metrics=(metric_fn, metric_args),
                    scaffold_fn=scaffold_fn)
            else:
                eval_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=metric_fn(*metric_args))

            return eval_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.reshape(features["label_ids"], [-1])

            predictions = {
                "logits": logits,
                "labels": label_ids,
                "is_real": features["is_real_example"]
            }

            if FLAGS.use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode, predictions=predictions)
            return output_spec

        # Configuring the optimizer
        train_op, learning_rate, _ = model_utils.get_train_op(
            FLAGS, total_loss)

        monitor_dict = {}
        monitor_dict["lr"] = learning_rate

        # Constructing training TPUEstimatorSpec.
        if FLAGS.use_tpu:
            # Creating host calls
            if not FLAGS.is_regression:
                label_ids = tf.reshape(features['label_ids'], [-1])
                predictions = tf.argmax(logits,
                                        axis=-1,
                                        output_type=label_ids.dtype)
                is_correct = tf.equal(predictions, label_ids)
                accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

                monitor_dict["accuracy"] = accuracy

                host_call = function_builder.construct_scalar_host_call(
                    monitor_dict=monitor_dict,
                    model_dir=FLAGS.model_dir,
                    prefix="train/",
                    reduce_fn=tf.reduce_mean)
            else:
                host_call = None

            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)
        else:
            train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=total_loss,
                                                    train_op=train_op)

        return train_spec
Example #3
  def model_fn(features, labels, mode, params):
    #### Training or Evaluation
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    #### Get loss from inputs
    outputs = function_builder.get_qa_outputs(FLAGS, features, is_training)

    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    scaffold_fn = None

    #### Prediction mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      if FLAGS.init_checkpoint:
        tf.logging.info("init_checkpoint not being used in predict mode.")

      predictions = {
          "unique_ids": features["unique_ids"],
          "start_top_index": outputs["start_top_index"],
          "start_top_log_probs": outputs["start_top_log_probs"],
          "end_top_index": outputs["end_top_index"],
          "end_top_log_probs": outputs["end_top_log_probs"],
          "cls_logits": outputs["cls_logits"]
      }

      if FLAGS.use_tpu:
        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
      else:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions)
      return output_spec

    #### Compute loss
    seq_length = tf.shape(features["input_ids"])[1]
    def compute_loss(log_probs, positions):
      one_hot_positions = tf.one_hot(
          positions, depth=seq_length, dtype=tf.float32)

      loss = - tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
      loss = tf.reduce_mean(loss)
      return loss

    start_loss = compute_loss(
        outputs["start_log_probs"], features["start_positions"])
    end_loss = compute_loss(
        outputs["end_log_probs"], features["end_positions"])

    total_loss = (start_loss + end_loss) * 0.5

    cls_logits = outputs["cls_logits"]
    is_impossible = tf.reshape(features["is_impossible"], [-1])
    regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=is_impossible, logits=cls_logits)
    regression_loss = tf.reduce_mean(regression_loss)

    # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
    # comparable to start_loss and end_loss
    total_loss += regression_loss * 0.5

    #### Configuring the optimizer
    train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)

    monitor_dict = {}
    monitor_dict["lr"] = learning_rate

    #### load pretrained models
    scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

    #### Constructing training TPUEstimatorSpec.
    if FLAGS.use_tpu:
      host_call = function_builder.construct_scalar_host_call(
          monitor_dict=monitor_dict,
          model_dir=FLAGS.model_dir,
          prefix="train/",
          reduce_fn=tf.reduce_mean)

      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
          scaffold_fn=scaffold_fn)
    else:
      train_spec = tf.estimator.EstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op)

    return train_spec
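
Example #3's `compute_loss` is plain categorical cross-entropy between a one-hot start/end position and the model's log-probabilities over sequence positions. A tiny NumPy sketch (illustrative only) of the same computation:

    import numpy as np

    log_probs = np.log(np.array([[0.1, 0.7, 0.2]]))  # batch of 1, seq_length = 3
    positions = np.array([1])                        # ground-truth position index

    one_hot = np.eye(3, dtype=np.float32)[positions]
    loss = -(one_hot * log_probs).sum(axis=-1).mean()
    print(loss)  # -log(0.7), roughly 0.357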
Example #4
    def model_fn(features, labels, mode, params):
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        return_dict = function_builder.get_classification_outputs(
            FLAGS, features, is_training)
        # per_example_loss = return_dict["per_example_loss"]
        cls_logits = return_dict["cls_logits"]
        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        logger.info('#params: {}'.format(num_params))

        #### load pretrained models
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                "feature_id": features["feature_id"],
                "cls_logits": cls_logits,
            }

            if FLAGS.use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode, predictions=predictions)
            return output_spec

        def compute_loss(log_probs, positions, depth):
            one_hot_positions = tf.one_hot(positions,
                                           depth=depth,
                                           dtype=tf.float32)

            loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
            loss = tf.reduce_mean(loss)
            return loss

        cls_log_probs = return_dict["cls_log_probs"]
        num_choices = FLAGS.num_choices
        if num_choices:
            num_classes = num_choices
        else:
            num_classes = FLAGS.num_classes
        total_loss = compute_loss(cls_log_probs,
                                  features["cls"],
                                  depth=num_classes)

        #### Configuring the optimizer
        train_op, learning_rate, _ = model_utils.get_train_op(
            FLAGS, total_loss)

        monitor_dict = {'loss/cls': total_loss, "lr": learning_rate}

        #### Constructing training TPUEstimatorSpec.
        if FLAGS.use_tpu:
            #### Creating host calls
            if not FLAGS.is_regression:
                label_ids = tf.reshape(features['cls'], [-1])
                predictions = tf.argmax(cls_logits,
                                        axis=-1,
                                        output_type=label_ids.dtype)
                is_correct = tf.equal(predictions, label_ids)
                accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

                monitor_dict["accuracy"] = accuracy

                host_call = function_builder.construct_scalar_host_call(
                    monitor_dict=monitor_dict,
                    model_dir=FLAGS.model_dir,
                    prefix="train/",
                    reduce_fn=tf.reduce_mean)
            else:
                host_call = None

            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)
        else:
            train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=total_loss,
                                                    train_op=train_op)

        return train_spec
Example #5
    def model_fn(features, labels, mode, params):
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        #### Get loss from inputs
        if FLAGS.is_regression:
            (total_loss, per_example_loss,
             logits) = function_builder.get_regression_loss(
                 FLAGS, features, is_training)
        else:
            (total_loss, per_example_loss,
             logits) = function_builder.get_classification_loss(
                 FLAGS, features, n_class, is_training)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info('#params: {}'.format(num_params))

        #### load pretrained models
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            assert FLAGS.num_hosts == 1

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                eval_input_dict = {
                    'labels': label_ids,
                    'predictions': predictions,
                    'weights': is_real_example
                }
                accuracy = tf.metrics.accuracy(**eval_input_dict)

                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)

                f1 = tf.contrib.metrics.f1_score(label_ids, predictions)

                # NOTE: earlier revisions tried to compute a confusion matrix
                # and sklearn-style macro precision/recall/F1/MCC here by
                # pulling label_ids/predictions out with sess.run()/eval().
                # That fails inside metric_fn ("GetNext() failed because the
                # iterator has not been initialized"): metric_fn must return
                # streaming tf.metrics-style (value, update_op) pairs, so only
                # the streaming f1 above is kept.

                return {'eval_accuracy': accuracy, 'eval_loss': loss, 'f1': f1}

            def regression_metric_fn(per_example_loss, label_ids, logits,
                                     is_real_example):
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                pearsonr = tf.contrib.metrics.streaming_pearson_correlation(
                    logits, label_ids, weights=is_real_example)
                return {'eval_loss': loss, 'eval_pearsonr': pearsonr}

            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)

            #### Constructing evaluation TPUEstimatorSpec.
            label_ids = tf.reshape(features['label_ids'], [-1])

            if FLAGS.is_regression:
                metric_fn = regression_metric_fn
            metric_args = [
                per_example_loss, label_ids, logits, is_real_example
            ]

            if FLAGS.use_tpu:
                eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metrics=(metric_fn, metric_args),
                    scaffold_fn=scaffold_fn)
            else:
                eval_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=metric_fn(*metric_args))

            return eval_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.reshape(features["label_ids"], [-1])

            predictions = {
                "logits": logits,
                "labels": label_ids,
                "is_real": features["is_real_example"]
            }

            if FLAGS.use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode, predictions=predictions)
            return output_spec

        #### Configuring the optimizer
        train_op, learning_rate, _ = model_utils.get_train_op(
            FLAGS, total_loss)

        monitor_dict = {}
        monitor_dict["lr"] = learning_rate

        #### Constructing training TPUEstimatorSpec.
        if FLAGS.use_tpu:
            #### Creating host calls
            if not FLAGS.is_regression:
                label_ids = tf.reshape(features['label_ids'], [-1])
                predictions = tf.argmax(logits,
                                        axis=-1,
                                        output_type=label_ids.dtype)
                is_correct = tf.equal(predictions, label_ids)
                accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

                monitor_dict["accuracy"] = accuracy

                host_call = function_builder.construct_scalar_host_call(
                    monitor_dict=monitor_dict,
                    model_dir=FLAGS.model_dir,
                    prefix="train/",
                    reduce_fn=tf.reduce_mean)
            else:
                host_call = None

            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)
        else:
            train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=total_loss,
                                                    train_op=train_op)

        return train_spec
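
The failed in-graph attempts noted in Example #5's metric_fn are much simpler to do outside the graph: run the PREDICT branch, collect labels and logits as NumPy arrays, and hand them to scikit-learn. A sketch under that assumption (the `estimator` and `eval_input_fn` names are illustrative):

    import numpy as np
    from sklearn.metrics import (matthews_corrcoef,
                                 precision_recall_fscore_support)

    label_ids, pred_ids = [], []
    for result in estimator.predict(input_fn=eval_input_fn):
        label_ids.append(int(result["labels"]))
        pred_ids.append(int(np.argmax(result["logits"], axis=-1)))

    precision, recall, f1, _ = precision_recall_fscore_support(
        label_ids, pred_ids, average="macro")
    mcc = matthews_corrcoef(label_ids, pred_ids)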
Example #6
    def model_fn(features, labels, mode, params):
        #### Training or Evaluation
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        #### Get loss from inputs
        if FLAGS.is_regression:
            (total_loss, per_example_loss,
             logits) = function_builder.get_regression_loss(
                 FLAGS, features, is_training)
        else:
            flag_val_dict = {
                "dropout": FLAGS.dropout,
                "model_dir": FLAGS.model_dir,
                "data_dir": FLAGS.data_dir,
                "use_tpu": FLAGS.use_tpu,
                "num_core_per_host": FLAGS.num_core_per_host,
                "master": FLAGS.master,
                "iterations": FLAGS.iterations,
                "learning_rate": FLAGS.learning_rate,
                "train_batch_size": FLAGS.train_batch_size,
                "model_config_path": FLAGS.model_config_path,
            }
            for name in list(features.keys()):
                t = features[name]
                # Cast int64 to int32, since TPUs do not support int64 tensors.
                if t.dtype == tf.int64:
                    t = tf.cast(t, tf.int32)
                features[name] = t
            tf.logging.info(json.dumps(flag_val_dict))
            (total_loss, per_example_loss, logits,
             probabilities) = function_builder.get_classification_loss(
                 FLAGS, features, n_class, is_training)

        #### Check model parameters
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info('#params: {}'.format(num_params))

        #### load pretrained models
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            assert FLAGS.num_hosts == 1

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                eval_input_dict = {
                    'labels': label_ids,
                    'predictions': predictions,
                    'weights': is_real_example
                }
                accuracy = tf.metrics.accuracy(**eval_input_dict)

                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                ###################################
                #  precision,recall, f1 score     #
                ###################################
                # `metrics` here is presumably the tf_metrics package;
                # 20 is the number of classes for this task.
                precision = metrics.precision(label_ids,
                                              predictions,
                                              20,
                                              average="macro")
                recall = metrics.recall(label_ids,
                                        predictions,
                                        20,
                                        average="macro")
                f = metrics.f1(label_ids, predictions, 20, average="macro")

                ###################################
                #      confusion matrix           #
                ###################################

                def eval_confusion_matrix(labels, predictions, num_classes):
                    with tf.variable_scope("eval_confusion_matrix"):
                        con_matrix = tf.confusion_matrix(
                            labels=labels,
                            predictions=predictions,
                            num_classes=num_classes)

                        con_matrix_sum = tf.Variable(
                            tf.zeros(shape=(num_classes, num_classes),
                                     dtype=tf.int32),
                            trainable=False,
                            name="confusion_matrix_result",
                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
                        update_op = tf.assign_add(con_matrix_sum, con_matrix)
                        return tf.convert_to_tensor(con_matrix_sum), update_op

                return {
                    'eval_accuracy': accuracy,
                    'eval_loss': loss,
                    "eval_precision": precision,
                    "eval_recall": recall,
                    "eval_f": f,
                    "conf_mat": eval_confusion_matrix(
                        label_ids, predictions, num_classes=20)
                }

            def regression_metric_fn(per_example_loss, label_ids, logits,
                                     is_real_example):
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                pearsonr = tf.contrib.metrics.streaming_pearson_correlation(
                    logits, label_ids, weights=is_real_example)
                return {'eval_loss': loss, 'eval_pearsonr': pearsonr}

            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)

            #### Constructing evaluation TPUEstimatorSpec.
            label_ids = tf.reshape(features['label_ids'], [-1])

            if FLAGS.is_regression:
                metric_fn = regression_metric_fn
            metric_args = [
                per_example_loss, label_ids, logits, is_real_example
            ]

            if FLAGS.use_tpu:
                eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metrics=(metric_fn, metric_args),
                    scaffold_fn=scaffold_fn)
            else:
                eval_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=metric_fn(*metric_args))

            return eval_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # NOTE: `probabilities` is only produced on the classification
            # path above, so this branch assumes a classification task.
            predictions = {"probabilities": probabilities}

            if FLAGS.use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode, predictions=predictions)
            return output_spec

        #### Configuring the optimizer
        train_op, learning_rate, _ = model_utils.get_train_op(
            FLAGS, total_loss)

        monitor_dict = {}
        monitor_dict["lr"] = learning_rate

        #### Constructing training TPUEstimatorSpec.
        if FLAGS.use_tpu:
            #### Creating host calls
            if not FLAGS.is_regression:
                label_ids = tf.reshape(features['label_ids'], [-1])
                predictions = tf.argmax(logits,
                                        axis=-1,
                                        output_type=label_ids.dtype)
                is_correct = tf.equal(predictions, label_ids)
                accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

                monitor_dict["accuracy"] = accuracy

                host_call = function_builder.construct_scalar_host_call(
                    monitor_dict=monitor_dict,
                    model_dir=FLAGS.model_dir,
                    prefix="train/",
                    reduce_fn=tf.reduce_mean)
            else:
                host_call = None

            train_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                host_call=host_call,
                scaffold_fn=scaffold_fn)
        else:
            train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=total_loss,
                                                    train_op=train_op)

        return train_spec
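
Every entry that a metric_fn returns, including the conf_mat pair built by eval_confusion_matrix in Example #6, must follow the tf.metrics contract: a (value_tensor, update_op) pair, where update_op accumulates local-variable state across eval batches and value_tensor reads the running result. A small self-contained TF 1.x sketch of that contract:

    import tensorflow as tf

    labels = tf.placeholder(tf.int32, [None])
    preds = tf.placeholder(tf.int32, [None])
    value, update_op = tf.metrics.accuracy(labels=labels, predictions=preds)

    with tf.Session() as sess:
        # Metric accumulators live in local variables.
        sess.run(tf.local_variables_initializer())
        # Two "eval batches": update_op accumulates, value reads the total.
        sess.run(update_op, {labels: [1, 0], preds: [1, 1]})  # 1 of 2 correct
        sess.run(update_op, {labels: [1, 1], preds: [1, 1]})  # 2 of 2 correct
        print(sess.run(value))  # 0.75 over all four examples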