Example #1
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator."""
        utils.log("Building model...")
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = FinetuningModel(config, tasks, is_training, features,
                                num_train_steps)

        # Load pre-trained weights from checkpoint
        init_checkpoint = config.init_checkpoint
        if pretraining_config is not None:
            init_checkpoint = tf.train.latest_checkpoint(
                pretraining_config.model_dir)
            utils.log("Using checkpoint", init_checkpoint)
        tvars = tf.trainable_variables()
        scaffold_fn = None
        if init_checkpoint:
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            if config.use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Build model for training or prediction
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                model.loss,
                config.learning_rate,
                num_train_steps,
                weight_decay_rate=config.weight_decay_rate,
                use_tpu=config.use_tpu,
                warmup_proportion=config.warmup_proportion,
                layerwise_lr_decay_power=config.layerwise_lr_decay,
                n_transformer_layers=model.bert_config.num_hidden_layers)
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=model.loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                training_hooks=[
                    training_utils.ETAHook(
                        {} if config.use_tpu else dict(loss=model.loss),
                        num_train_steps, config.iterations_per_loop,
                        config.use_tpu, 10)
                ])
        else:
            assert mode == tf.estimator.ModeKeys.PREDICT
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=utils.flatten_dict(model.outputs),
                scaffold_fn=scaffold_fn)

        utils.log("Building complete")
        return output_spec
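
A model_fn like this is not called directly; it is handed to a TPUEstimator, which calls it once per mode. Below is a minimal wiring sketch under the TF 1.x API; train_input_fn, the batch sizes, and model_dir are hypothetical placeholders, and only model_fn itself comes from the example above:

    import tensorflow as tf  # TF 1.x API

    run_config = tf.estimator.tpu.RunConfig(
        model_dir="/tmp/finetune",  # placeholder: any writable directory
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=config.iterations_per_loop))
    estimator = tf.estimator.tpu.TPUEstimator(
        use_tpu=config.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=config.train_batch_size,
        predict_batch_size=config.predict_batch_size)
    # train_input_fn is assumed to yield the features dict that
    # FinetuningModel expects.
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)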
Example #2
    def _train(self, outputs, labels):
        prediction = outputs["prediction"]
        loss = outputs["loss"]
        train_op = optimization.create_optimizer(
            loss, self.bert_config["learning_rate"],
            self.bert_config["num_train_steps"],
            self.bert_config["num_warmup_steps"], self.bert_config["use_tpu"])

        output_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=train_op,
            training_hooks=[
                tf.train.LoggingTensorHook(
                    {
                        "loss": loss,
                        "step": tf.train.get_global_step(),
                        "acc": 100. * tf.reduce_mean(
                            tf.cast(
                                tf.equal(tf.cast(prediction, tf.int32),
                                         tf.cast(labels, tf.int32)),
                                tf.float32))
                    },
                    every_n_iter=100)
            ])
        return output_spec
Example #3
 def model_fn(features, labels, mode, params):
   """Build the model for training."""
   model = PretrainingModel(config, features,
                            mode == tf.estimator.ModeKeys.TRAIN)
   utils.log("Model is built!")
   if mode == tf.estimator.ModeKeys.TRAIN:
     train_op = optimization.create_optimizer(
         model.total_loss, config.learning_rate, config.num_train_steps,
         weight_decay_rate=config.weight_decay_rate,
         use_tpu=config.use_tpu,
         warmup_steps=config.num_warmup_steps,
         lr_decay_power=config.lr_decay_power
     )
     output_spec = tf.estimator.tpu.TPUEstimatorSpec(
         mode=mode,
         loss=model.total_loss,
         train_op=train_op,
         training_hooks=[training_utils.ETAHook(
             {} if config.use_tpu else dict(loss=model.total_loss),
             config.num_train_steps, config.iterations_per_loop,
             config.use_tpu)]
     )
   elif mode == tf.estimator.ModeKeys.EVAL:
     output_spec = tf.estimator.tpu.TPUEstimatorSpec(
         mode=mode,
         loss=model.total_loss,
         eval_metrics=model.eval_metrics,
         evaluation_hooks=[training_utils.ETAHook(
             {} if config.use_tpu else dict(loss=model.total_loss),
             config.num_eval_steps, config.iterations_per_loop,
             config.use_tpu, is_training=False)])
   else:
     raise ValueError("Only TRAIN and EVAL modes are supported")
   return output_spec
Example #4
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator."""
        utils.log("Building model...")
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = FinetuningModel(config, tasks, is_training, features,
                                num_train_steps)

        # Load pre-trained weights from checkpoint
        init_checkpoint = config.init_checkpoint
        if pretraining_config is not None:
            init_checkpoint = tf.train.latest_checkpoint(
                pretraining_config.model_dir)
            utils.log("Using checkpoint", init_checkpoint)
        tvars = tf.trainable_variables()
        scaffold_fn = None
        initialized_variable_names = {}
        if init_checkpoint:
            utils.log("Using checkpoint", init_checkpoint)
            assignment_map, initialized_variable_names = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        utils.log("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            utils.log("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

        # Build model for training or prediction
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                model.loss,
                config.learning_rate,
                num_train_steps,
                weight_decay_rate=config.weight_decay_rate,
                warmup_proportion=config.warmup_proportion,
                n_transformer_layers=model.bert_config.num_hidden_layers)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=model.loss,
                train_op=train_op,
                training_hooks=[
                    training_utils.ETAHook(
                        {} if config.use_tpu else dict(loss=model.loss),
                        num_train_steps, config.iterations_per_loop,
                        config.use_tpu, 10)
                ])
        else:
            assert mode == tf.estimator.ModeKeys.PREDICT
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode, predictions=utils.flatten_dict(model.outputs))

        utils.log("Building complete")
        return output_spec
Example #5
    def model_fn(features, labels, mode, params):
        """Build the model for training."""
        model = PretrainingModel(config, features,
                                 mode == tf.estimator.ModeKeys.TRAIN)
        utils.log("Model is built!")
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op, optimizer = optimization.create_optimizer(
                model.total_loss,
                config.learning_rate,
                config.num_train_steps,
                weight_decay_rate=config.weight_decay_rate,
                use_tpu=config.use_tpu,
                warmup_steps=config.num_warmup_steps,
                lr_decay_power=config.lr_decay_power)

            eta_hook = training_utils.ETAHook(
                {} if config.use_tpu else dict(
                    Total_loss=model.total_loss,
                    MLM_loss=model.mlm_output_loss,
                    RTD_loss=model.disc_output_loss,
                    learning_rate=optimizer.learning_rate,
                    MLM_accuracy=model.metrics['masked_lm_accuracy'],
                    Sampled_MLM_accuracy=model.metrics[
                        'sampled_masked_lm_accuracy'],
                    RTD_accuracy=model.metrics['disc_accuracy'],
                    RTD_precision=model.metrics['disc_precision'],
                    RTD_recall=model.metrics['disc_recall'],
                    RTD_auc=model.metrics['disc_auc'],
                ), config.num_train_steps, config.iterations_per_loop,
                config.use_tpu)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=model.total_loss,
                train_op=train_op,
                training_hooks=[eta_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=model.total_loss,
                eval_metrics=model.eval_metrics,
                evaluation_hooks=[
                    training_utils.ETAHook({} if config.use_tpu else dict(
                        loss=model.total_loss,
                        mlm_loss=model.mlm_output_loss,
                        disc_loss=model.disc_output_loss),
                                           config.num_eval_steps,
                                           config.iterations_per_loop,
                                           config.use_tpu,
                                           is_training=False)
                ])
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported")
        return output_spec
Example #6

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        segment_ids = features["segment_ids"]
        input_ids = features['masked_input']
        input_mask = tf.cast(features['pad_mask'], dtype=tf.int32)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling_funnelformer.FunnelTFM(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        masked_lm_preds = tf.argmax(masked_lm_log_probs,
                                    axis=-1,
                                    output_type=tf.int32)

        total_loss = masked_lm_loss
        monitor_dict = {}

        tvars = tf.trainable_variables()
        for tvar in tvars:
            tf.logging.info("  tvar = %s", tvar)

        eval_fn_inputs = {
            "masked_lm_preds": masked_lm_preds,
            "masked_lm_loss": masked_lm_example_loss,
            "masked_lm_weights": masked_lm_weights,
            "masked_lm_ids": masked_lm_ids
        }

        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

        def monitor_fn(eval_fn_inputs, keys):
            # d = {k: arg for k, arg in zip(eval_fn_keys, args)}
            d = {}
            for key in eval_fn_inputs:
                if key in keys:
                    d[key] = eval_fn_inputs[key]
            monitor_dict = dict()
            masked_lm_ids = tf.reshape(d["masked_lm_ids"], [-1])
            masked_lm_preds = tf.reshape(d["masked_lm_preds"], [-1])
            masked_lm_weights = tf.reshape(d["masked_lm_weights"], [-1])
            tf.logging.info("masked_lm_preds: %s", masked_lm_preds)
            tf.logging.info("masked_lm_ids: %s", masked_lm_ids)
            tf.logging.info("masked_lm_weights: %s", masked_lm_weights)
            # masked_lm_pred_ids = tf.argmax(masked_lm_preds, axis=-1,
            #                             output_type=tf.int32)
            masked_lm_acc = tf.cast(tf.equal(masked_lm_preds, masked_lm_ids),
                                    dtype=tf.float32)
            masked_lm_acc = tf.reduce_sum(
                masked_lm_acc * tf.cast(masked_lm_weights, dtype=tf.float32))
            masked_lm_acc /= (
                1e-10 +
                tf.reduce_sum(tf.cast(masked_lm_weights, dtype=tf.float32)))

            masked_lm_loss = tf.reshape(d["masked_lm_loss"], [-1])
            masked_lm_loss = tf.reduce_sum(
                masked_lm_loss * tf.cast(masked_lm_weights, dtype=tf.float32))
            masked_lm_loss /= (
                1e-10 +
                tf.reduce_sum(tf.cast(masked_lm_weights, dtype=tf.float32)))

            monitor_dict['masked_lm_loss'] = masked_lm_loss
            monitor_dict['masked_lm_acc'] = masked_lm_acc

            return monitor_dict

        monitor_dict = monitor_fn(eval_fn_inputs, eval_fn_keys)

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling_funnelformer.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op, output_learning_rate = optimization.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                weight_decay_rate=FLAGS.weight_decay_rate,
                use_tpu=use_tpu,
                warmup_steps=num_warmup_steps,
                lr_decay_power=FLAGS.lr_decay_power)

            monitor_dict['learning_rate'] = output_learning_rate
            if FLAGS.monitoring and monitor_dict:
                host_call = log_utils.construct_scalar_host_call_v1(
                    monitor_dict=monitor_dict,
                    model_dir=FLAGS.output_dir,
                    prefix="train/")
            else:
                host_call = None

            tf.logging.info("host_call: %s", host_call)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
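
The host_call built by log_utils.construct_scalar_host_call_v1 above is project-specific, but the TPUEstimatorSpec host_call contract it satisfies is generic: a (fn, tensor_list) pair whose fn runs on the host CPU, receives each tensor with a leading per-core dimension, and returns summary ops. A minimal sketch of that contract, assuming TF 1.x tf.contrib.summary; the names and log directory are illustrative, not the helper's actual output:

    def host_call_fn(global_step, loss):
        # Runs on the host; each argument has a leading batch dimension.
        with tf.contrib.summary.create_file_writer("/tmp/train").as_default():
            with tf.contrib.summary.always_record_summaries():
                tf.contrib.summary.scalar("loss", loss[0], step=global_step[0])
                return tf.contrib.summary.all_summary_ops()

    gs = tf.reshape(tf.train.get_or_create_global_step(), [1])
    host_call = (host_call_fn, [gs, tf.reshape(total_loss, [1])])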
Example #7
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        keyword_mask = features["keyword_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        embedding_table = tf.get_variable("embedding_table",
                                          shape=[model_config.vocab_size, model_config.vocab_vec_size],
                                          trainable=embedding_table_trainable)

        def init_embedding_table(scaffold, sess):
            sess.run(embedding_table.initializer,
                     {embedding_table.initial_value: embedding_table_value})

        if embedding_table_value is not None:
            scaffold = tf.train.Scaffold(init_fn=init_embedding_table)
        else:
            scaffold = None

        (total_loss, text_representation, keyword_probs) = create_model(model_config,
                                                                       is_training,
                                                                       input_ids,
                                                                       input_mask,
                                                                       keyword_mask,
                                                                       segment_ids,
                                                                       embedding_table)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}

        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

            log_hook = tf.train.LoggingTensorHook({"total_loss": total_loss},
                                                  every_n_iter=10)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[log_hook],
                scaffold=scaffold)
        elif mode == tf.estimator.ModeKeys.EVAL:
            #def metric_fn(per_example_loss, label_ids, logits):
            #    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            #    accuracy = tf.metrics.accuracy(
            #        labels=label_ids, predictions=predictions)
            #    loss = tf.metrics.mean(values=per_example_loss)
            #    return {
            #        "eval_accuracy": accuracy,
            #        "eval_loss": loss,
            #    }

            #eval_metrics = (metric_fn,
            #                [per_example_loss, label_ids, logits])
            #predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            #eval_accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions)
            #eval_loss = tf.metrics.mean(values=per_example_loss)
            #output_spec = tf.estimator.EstimatorSpec(
            #    mode=mode,
            #    loss=total_loss,
            #    eval_metric_ops={"eval_accuracy":eval_accuracy, "eval_loss":eval_loss},
            #    scaffold=scaffold)
            pass
        else:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={"text_representation": text_representation, "input_ids":input_ids, "keyword_probs":keyword_probs},
                prediction_hooks=None,
                scaffold=scaffold)
        return output_spec
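
The embedding-table setup in this example uses the Scaffold init_fn hook to feed the variable's initial_value at session-creation time, which keeps a large pretrained array out of the serialized graph. A self-contained sketch of the same pattern, with made-up shapes and random data standing in for real vectors:

    import numpy as np
    import tensorflow as tf  # TF 1.x API

    pretrained = np.random.rand(100, 8).astype(np.float32)  # hypothetical vectors

    table = tf.get_variable("embedding_table", shape=pretrained.shape)

    def init_table(scaffold, sess):
        # Feed the initializer's input instead of baking the array into
        # the graph as a constant.
        sess.run(table.initializer, {table.initial_value: pretrained})

    scaffold = tf.train.Scaffold(init_fn=init_table)
    with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
        print(sess.run(table)[0, :3])  # rows now match pretrained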
Example #8
    def model_fn(self, features, labels, mode):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        bert_config = self.config
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        qids = features['qids']
        input_ids = features["input_ids_list"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids_list"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = tf.cast(features["masked_lm_weights"], tf.float32)
        next_sentence_labels = labels

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        tf.logging.info("is_training: %s", is_training)

        model = bert_base.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=bert_config["use_one_hot_embeddings"])

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs) =\
            get_masked_lm_output(
            bert_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,next_sentence_log_probs) =\
            get_cate_prediction_output(
            bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        eval_metrics = self._metric_fn(masked_lm_example_loss,
                                       masked_lm_log_probs, masked_lm_ids,
                                       masked_lm_weights,
                                       next_sentence_example_loss,
                                       next_sentence_log_probs,
                                       next_sentence_labels)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if bert_config["init_checkpoint"]:
            (assignment_map, initialized_variable_names) = \
                bert_base.get_assignment_map_from_checkpoint(tvars, bert_config["init_checkpoint"])
            tf.train.init_from_checkpoint(bert_config["init_checkpoint"],
                                          assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, bert_config["learning_rate"],
                bert_config["num_train_steps"],
                bert_config["num_warmup_steps"], bert_config["use_tpu"])

            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[
                    tf.train.LoggingTensorHook(
                        {
                            "loss": total_loss,
                            "step": tf.train.get_global_step(),
                            "masked_lm_acc": eval_metrics["masked_lm_accuracy"],
                            "masked_lm_loss": eval_metrics["masked_lm_loss"],
                            "next_sentence_acc": eval_metrics["next_sentence_accuracy"],
                            "next_sentence_loss": eval_metrics["next_sentence_loss"]
                        },
                        every_n_iter=100)
                ])
        elif mode == tf.estimator.ModeKeys.PREDICT:
            outputs = dict(oneid=qids)
            if bert_config['out_embedding_type'] == "sequence_output":
                outputs['out_embedding'] = model.get_sequence_output()
            elif bert_config['out_embedding_type'] == "pooled_output":
                outputs['out_embedding'] = model.get_pooled_output()
            else:
                raise ValueError("Not recognized output embedding")
            output_spec = tf.estimator.EstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT, predictions=outputs)
        else:
            raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                             (mode))

        slim.model_analyzer.analyze_vars(tf.trainable_variables(),
                                         print_info=True)
        return output_spec
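
The variable analysis on the last line relies on the contrib slim namespace, so this snippet presumably sits under an import along the lines of:

    from tensorflow.contrib import slim  # TF 1.x only; contrib was removed in TF 2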
Example #9
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids_a = features["input_ids_a"]
        input_mask_a = features["input_mask_a"]
        segment_ids_a = features["segment_ids_a"]

        if not do_encode:
            input_ids_b = features["input_ids_b"]
            input_mask_b = features["input_mask_b"]
            segment_ids_b = features["segment_ids_b"]
            label = features["label"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        embedding_table = tf.get_variable(
            "embedding_table",
            shape=[model_config.vocab_size, model_config.vocab_vec_size],
            trainable=embedding_table_trainable)

        def init_embedding_table(scaffold, sess):
            sess.run(embedding_table.initializer,
                     {embedding_table.initial_value: embedding_table_value})

        if embedding_table_value is not None:
            scaffold = tf.train.Scaffold(init_fn=init_embedding_table)
        else:
            scaffold = None

        if do_encode:
            text_representation = create_encode_model(model_config,
                                                      is_training, input_ids_a,
                                                      input_mask_a,
                                                      segment_ids_a,
                                                      embedding_table)
        else:
            (total_loss, cosine) = create_similar_model(
                model_config, is_training, input_ids_a, input_mask_a,
                segment_ids_a, input_ids_b, input_mask_b, segment_ids_b, label,
                embedding_table)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}

        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if not do_encode:
            if mode == tf.estimator.ModeKeys.TRAIN:
                train_op = optimization.create_optimizer(total_loss,
                                                         learning_rate,
                                                         num_train_steps,
                                                         num_warmup_steps,
                                                         use_tpu=False)

                log_hook = tf.train.LoggingTensorHook(
                    {"total_loss": total_loss}, every_n_iter=100)
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=[log_hook],
                    scaffold=scaffold)
            elif mode == tf.estimator.ModeKeys.EVAL:
                pass
            else:
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={"cosine": cosine},
                    prediction_hooks=None,
                    scaffold=scaffold)
        else:
            output_spec = tf.estimator.EstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions={"text_representation": text_representation},
                prediction_hooks=None,
                scaffold=scaffold)
        return output_spec
Example #10

    def model_fn(features, labels, mode, params):
        """The `model_fn` for Estimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        input_ids1 = features["input_ids1"]
        input_mask1 = features["input_mask1"]
        segment_ids1 = features["segment_ids1"]

        input_ids2 = features["input_ids2"]
        input_mask2 = features["input_mask2"]
        segment_ids2 = features["segment_ids2"]

        input_ids3 = features["input_ids3"]
        input_mask3 = features["input_mask3"]
        segment_ids3 = features["segment_ids3"]

        label_ids = features["label_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        with tf.get_default_graph().as_default():
            tf.set_random_seed(1234)
            with tf.name_scope('replica_0'), \
              tf.variable_scope('replica_0'):
                with tf.device('/gpu:0'):
                    query, q_h = create_model(
                        bert_config,
                        is_training,
                        input_ids,
                        input_mask,
                        segment_ids,
                        use_one_hot_embeddings=use_one_hot_embeddings)
                    query = tf.expand_dims(query, axis=1)
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  {'bert/': 'replica_0/bert/'})

            with tf.device('/gpu:0'):
                cand1, cand1_h = create_model(
                    bert_config,
                    is_training,
                    input_ids1,
                    input_mask1,
                    segment_ids1,
                    use_one_hot_embeddings=use_one_hot_embeddings)
                cand1 = tf.expand_dims(cand1, axis=1)

            with tf.device('/gpu:0'):
                cand2, cand2_h = create_model(
                    bert_config,
                    is_training,
                    input_ids2,
                    input_mask2,
                    segment_ids2,
                    use_one_hot_embeddings=use_one_hot_embeddings)
                cand2 = tf.expand_dims(cand2, axis=1)

            with tf.device('/gpu:0'):
                cand3, cand3_h = create_model(
                    bert_config,
                    is_training,
                    input_ids3,
                    input_mask3,
                    segment_ids3,
                    use_one_hot_embeddings=use_one_hot_embeddings)
                cand3 = tf.expand_dims(cand3, axis=1)

            with tf.device('/gpu:0'):
                cands = tf.concat([cand1, cand2, cand3], axis=1)
                scores = tf.matmul(query, cands, transpose_b=True)
                logits = tf.squeeze(scores, axis=1)

                log_probs = tf.nn.log_softmax(logits, axis=-1)

                one_hot_labels = tf.one_hot(label_ids,
                                            depth=num_labels,
                                            dtype=tf.float32,
                                            name="output_labels")

                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                  axis=-1)
                total_loss = tf.reduce_mean(per_example_loss)

        tvars = tf.trainable_variables()
        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            tf.logging.info("  name = %s, shape = %s", var.name, var.shape)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          loss=total_loss,
                                                          train_op=train_op)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          predictions={
                                                              'predictions':
                                                              predictions,
                                                          })
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)
                loss = tf.metrics.mean(per_example_loss)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, eval_metrics=eval_metrics)
        else:
            raise ValueError(
                "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))

        return output_spec
Example #11

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = features["next_sentence_labels"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss + next_sentence_loss

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights, next_sentence_example_loss,
                    next_sentence_log_probs, next_sentence_labels):
        """Computes the loss and accuracy of the model."""
        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                         [-1, masked_lm_log_probs.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_log_probs, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        next_sentence_log_probs = tf.reshape(
            next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
        next_sentence_predictions = tf.argmax(
            next_sentence_log_probs, axis=-1, output_type=tf.int32)
        next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
        next_sentence_accuracy = tf.metrics.accuracy(
            labels=next_sentence_labels, predictions=next_sentence_predictions)
        next_sentence_mean_loss = tf.metrics.mean(
            values=next_sentence_example_loss)

        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
            "next_sentence_accuracy": next_sentence_accuracy,
            "next_sentence_loss": next_sentence_mean_loss,
        }

      eval_metrics = (metric_fn, [
          masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
          masked_lm_weights, next_sentence_example_loss,
          next_sentence_log_probs, next_sentence_labels
      ])
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    return output_spec
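
Note the shape of eval_metrics in the EVAL branch: on TPUs, metric ops cannot be created inside the replicated model_fn, so TPUEstimatorSpec takes a (metric_fn, tensor_list) pair and evaluates metric_fn on the host. A stripped-down sketch of that contract (the tensor names are illustrative):

    def metric_fn(labels, predictions):
        # Executed on the host CPU, outside the TPU training loop.
        return {"accuracy": tf.metrics.accuracy(labels, predictions)}

    spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=total_loss,
        eval_metrics=(metric_fn, [label_ids, predictions]))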
Example #12

        def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
            from tensorflow.python.estimator.model_fn import EstimatorSpec

            tf.logging.info("*** Features ***")
            for name in sorted(features.keys()):
                tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

            input_ids = features["input_ids"]
            input_mask = features["input_mask"]
            segment_ids = features["segment_ids"]
            label_ids = features["label_ids"]

            is_training = (mode == tf.estimator.ModeKeys.TRAIN)

            (total_loss, per_example_loss, logits, probabilities) = self.create_model(
                bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
                num_labels, use_one_hot_embeddings)

            tvars = tf.trainable_variables()
            initialized_variable_names = {}

            if init_checkpoint:
                (assignment_map, initialized_variable_names) \
                    = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            tf.logging.info("**** Trainable Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                                init_string)

            if mode == tf.estimator.ModeKeys.TRAIN:

                train_op = optimization.create_optimizer(
                    total_loss, learning_rate, num_train_steps, num_warmup_steps, False)

                output_spec = EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op)
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(per_example_loss, label_ids, logits):
                    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                    accuracy = tf.metrics.accuracy(label_ids, predictions)
                    auc = tf.metrics.auc(label_ids, predictions)
                    loss = tf.metrics.mean(per_example_loss)
                    return {
                        "eval_accuracy": accuracy,
                        "eval_auc": auc,
                        "eval_loss": loss,
                    }

                eval_metrics = metric_fn(per_example_loss, label_ids, logits)
                output_spec = EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=eval_metrics)
            else:
                output_spec = EstimatorSpec(mode=mode, predictions=probabilities)

            return output_spec
Example #13
    def model_fn(features, labels, mode, params):
        """Build the model for training."""
        if config.masking_strategy in (pretrain_helpers.ADVERSARIAL_STRATEGY,
                                       pretrain_helpers.MIX_ADV_STRATEGY):
            model = AdversarialPretrainingModel(
                config, features, mode == tf.estimator.ModeKeys.TRAIN)
        elif config.masking_strategy == pretrain_helpers.RW_STRATEGY:
            ratio = []
            with open(config.ratio_file, "r") as fin:
                for line in fin:
                    line = line.strip()
                    if line:
                        tok = line.split()
                        ratio.append(float(tok[1]))
            model = RatioBasedPretrainingModel(
                config, features, ratio, mode == tf.estimator.ModeKeys.TRAIN)
        else:
            model = PretrainingModel(config, features,
                                     mode == tf.estimator.ModeKeys.TRAIN)
        utils.log("Model is built!")

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        if config.init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, config.init_checkpoint)
            tf.train.init_from_checkpoint(config.init_checkpoint,
                                          assignment_map)

        utils.log("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            utils.log("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if config.masking_strategy == pretrain_helpers.ADVERSARIAL_STRATEGY:
                student_train_op = optimization.create_optimizer(
                    model.mlm_loss,
                    config.learning_rate,
                    config.num_train_steps,
                    weight_decay_rate=config.weight_decay_rate,
                    use_tpu=config.use_tpu,
                    warmup_steps=config.num_warmup_steps,
                    lr_decay_power=config.lr_decay_power)
                teacher_train_op = optimization.create_optimizer(
                    model.teacher_loss,
                    config.teacher_learning_rate,
                    config.num_train_steps,
                    lr_decay_power=config.lr_decay_power)
                train_op = tf.group(student_train_op, teacher_train_op)
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=model.total_loss,
                    train_op=train_op,
                    training_hooks=[
                        training_utils.ETAHook(
                            dict(loss=model.mlm_loss,
                                 teacher_loss=model.teacher_loss,
                                 reward=model._baseline),
                            config.num_train_steps, config.iterations_per_loop,
                            config.use_tpu)
                    ])
            else:
                train_op = optimization.create_optimizer(
                    model.total_loss,
                    config.learning_rate,
                    config.num_train_steps,
                    weight_decay_rate=config.weight_decay_rate,
                    use_tpu=config.use_tpu,
                    warmup_steps=config.num_warmup_steps,
                    lr_decay_power=config.lr_decay_power)
                output_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=model.total_loss,
                    train_op=train_op,
                    training_hooks=[
                        training_utils.ETAHook(dict(loss=model.total_loss),
                                               config.num_train_steps,
                                               config.iterations_per_loop,
                                               config.use_tpu)
                    ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=model.total_loss,
                eval_metric_ops=model.eval_metrics,
                evaluation_hooks=[
                    training_utils.ETAHook(dict(loss=model.total_loss),
                                           config.num_eval_steps,
                                           config.iterations_per_loop,
                                           config.use_tpu,
                                           is_training=False)
                ])
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported")
        return output_spec
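
The RW_STRATEGY branch above keeps tok[1] of every non-empty line of config.ratio_file, so the file is presumably whitespace-separated token/ratio pairs, one per line; a hypothetical two-line example:

    the 0.0103
    electra 0.4417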
Example #14
    def model_fn(features, labels, mode, params):
        """Build the model for training."""
        model = PretrainingModel(
            config=config,
            features=features,
            is_training=mode == tf.estimator.ModeKeys.TRAIN,
            init_checkpoint=config.init_checkpoint)
        utils.log("Model is built!")
        to_log = {
            "gen_loss": model.mlm_output.loss,
            "disc_loss": model.disc_output.loss,
            "total_loss": model.total_loss
        }
        if mode == tf.estimator.ModeKeys.TRAIN:

            tf.summary.scalar('gen_loss', model.mlm_output.loss)
            tf.summary.scalar('disc_loss', model.disc_output.loss)
            tf.summary.scalar('total_loss', model.total_loss)

            lr_multiplier = hvd.size() if config.scale_lr else 1
            train_op = optimization.create_optimizer(
                loss=model.total_loss,
                learning_rate=config.learning_rate * lr_multiplier,
                num_train_steps=config.num_train_steps,
                weight_decay_rate=config.weight_decay_rate,
                warmup_steps=config.num_warmup_steps,
                warmup_proportion=0,
                lr_decay_power=config.lr_decay_power,
                layerwise_lr_decay_power=-1,
                n_transformer_layers=None,
                hvd=hvd,
                use_fp16=config.use_fp16,
                num_accumulation_steps=config.num_accumulation_steps,
                allreduce_post_accumulation=config.allreduce_post_accumulation)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=model.total_loss,
                train_op=train_op,
                training_hooks=[
                    training_utils.ETAHook(
                        to_log=to_log,
                        n_steps=config.num_train_steps,
                        iterations_per_loop=config.iterations_per_loop,
                        on_tpu=False,
                        log_every=1,
                        is_training=True)
                ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=model.total_loss,
                eval_metric_ops=model.eval_metrics,
                evaluation_hooks=[
                    training_utils.ETAHook(
                        to_log=to_log,
                        n_steps=config.num_eval_steps,
                        iterations_per_loop=config.iterations_per_loop,
                        on_tpu=False,
                        log_every=1,
                        is_training=False)
                ])
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported")
        return output_spec
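
This variant threads hvd (Horovod) into create_optimizer and scales the learning rate by hvd.size(); the launcher code is not shown. A minimal sketch of the usual Horovod-with-Estimator setup, using only documented horovod.tensorflow calls (the model_dir and input_fn are placeholders):

    import horovod.tensorflow as hvd

    hvd.init()  # one process per GPU, e.g. launched with horovodrun
    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())
    run_config = tf.estimator.RunConfig(
        # Only rank 0 writes checkpoints, avoiding concurrent writers.
        model_dir="/tmp/pretrain" if hvd.rank() == 0 else None,
        session_config=session_config)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    estimator.train(
        input_fn=train_input_fn,
        hooks=[hvd.BroadcastGlobalVariablesHook(0)],  # sync initial weights
        max_steps=config.num_train_steps)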
Example #15
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, input_ids,
                                       input_mask, segment_ids, label_ids,
                                       num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits):
                #print("###metric_fn.logits:",logits.shape) # (?,80)
                #predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                #print("###metric_fn.label_ids:",label_ids.shape,";predictions:",predictions.shape) # label_ids: (?,80);predictions:(?,)
                logits_split = tf.split(
                    logits, FLAGS.num_aspects,
                    axis=-1)  # a list. length is num_aspects
                label_ids_split = tf.split(
                    label_ids, FLAGS.num_aspects,
                    axis=-1)  # a list. length is num_aspects
                accuracy = tf.constant(0.0, dtype=tf.float64)

                for j, logits in enumerate(logits_split):  #
                    #  accuracy = tf.metrics.accuracy(label_ids, predictions)

                    predictions = tf.argmax(
                        logits, axis=-1,
                        output_type=tf.int32)  # should be [batch_size,]
                    label_id_ = tf.cast(tf.argmax(label_ids_split[j], axis=-1),
                                        dtype=tf.int32)
                    print("label_ids_split[j]:", label_ids_split[j],
                          ";predictions:", predictions, ";label_id_:",
                          label_id_)
                    current_accuracy, update_op_accuracy = tf.metrics.accuracy(
                        label_id_, predictions)
                    accuracy += tf.cast(current_accuracy, dtype=tf.float64)
                accuracy = accuracy / tf.constant(FLAGS.num_aspects,
                                                  dtype=tf.float64)
                loss = tf.metrics.mean(per_example_loss)
                return {
                    "eval_accuracy": (accuracy, update_op_accuracy),
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)
        return output_spec
Example #16
  def model_fn(features, labels, mode, params):
    """Build the model for training."""
    model = PretrainingModel(config, features,
                             mode == tf.estimator.ModeKeys.TRAIN)
    utils.log("Model is built!")

    # Load pre-trained weights from checkpoint
    init_checkpoint = tf.train.latest_checkpoint(config.init_checkpoint)
    utils.log("Using checkpoint", init_checkpoint)
    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      assignment_map, initialized_variable_names = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if config.use_tpu:
        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()
        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    utils.log("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      utils.log("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          model.total_loss, config.learning_rate, config.num_train_steps,
          weight_decay_rate=config.weight_decay_rate,
          use_tpu=config.use_tpu,
          warmup_steps=config.num_warmup_steps,
          lr_decay_power=config.lr_decay_power
      )
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=model.total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          training_hooks=[training_utils.ETAHook(
              {} if config.use_tpu else dict(loss=model.total_loss),
              config.num_train_steps, config.iterations_per_loop,
              config.use_tpu)]
      )
    elif mode == tf.estimator.ModeKeys.EVAL:
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=model.total_loss,
          scaffold_fn=scaffold_fn,
          eval_metrics=model.eval_metrics,
          evaluation_hooks=[training_utils.ETAHook(
              {} if config.use_tpu else dict(loss=model.total_loss),
              config.num_eval_steps, config.iterations_per_loop,
              config.use_tpu, is_training=False)])
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported")
    return output_spec