def model_fn(features, labels, mode, params):
    """Build the `TPUEstimatorSpec` used for fine-tuning.

    The signature follows the Estimator `model_fn` contract. `labels` and
    `params` are accepted but never read. The function also reads
    enclosing-scope names that are not defined here: `config`, `tasks`,
    `num_train_steps` and `pretraining_config`.

    Args:
      features: input tensors, handed unchanged to `FinetuningModel`.
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value; only TRAIN and PREDICT are
        handled (any other mode trips the assertion).
      params: unused.

    Returns:
      A `tf.estimator.tpu.TPUEstimatorSpec`.
    """
    utils.log("Building model...")
    training = mode == tf.estimator.ModeKeys.TRAIN
    model = FinetuningModel(config, tasks, training, features, num_train_steps)

    # Start from the statically configured checkpoint; when a pre-training run
    # is configured, its newest checkpoint replaces it.
    ckpt = config.init_checkpoint
    if pretraining_config is not None:
        ckpt = tf.train.latest_checkpoint(pretraining_config.model_dir)
        utils.log("Using checkpoint", ckpt)

    trainable = tf.trainable_variables()
    scaffold_fn = None
    if ckpt:
        assignment_map, _unused_names = (
            modeling.get_assignment_map_from_checkpoint(trainable, ckpt))
        if not config.use_tpu:
            tf.train.init_from_checkpoint(ckpt, assignment_map)
        else:
            # On TPU the restore is deferred into `scaffold_fn`, which the
            # TPUEstimator invokes itself.
            def _restore_scaffold():
                tf.train.init_from_checkpoint(ckpt, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = _restore_scaffold

    if not training:
        assert mode == tf.estimator.ModeKeys.PREDICT
        output_spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=utils.flatten_dict(model.outputs),
            scaffold_fn=scaffold_fn,
        )
    else:
        train_op = optimization.create_optimizer(
            model.loss,
            config.learning_rate,
            num_train_steps,
            weight_decay_rate=config.weight_decay_rate,
            use_tpu=config.use_tpu,
            warmup_proportion=config.warmup_proportion,
            layerwise_lr_decay_power=config.layerwise_lr_decay,
            n_transformer_layers=model.bert_config.num_hidden_layers,
        )
        # Off-TPU the hook also reports the loss tensor; on TPU it gets nothing
        # to log. The trailing 10 is passed as the hook's fifth positional arg.
        eta_hook = training_utils.ETAHook(
            {} if config.use_tpu else dict(loss=model.loss),
            num_train_steps,
            config.iterations_per_loop,
            config.use_tpu,
            10,
        )
        output_spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=model.loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn,
            training_hooks=[eta_hook],
        )

    utils.log("Building complete")
    return output_spec
def model_fn(features, labels, mode, params):
    """The `model_fn` for a plain (non-TPU) `tf.estimator.Estimator`.

    Builds a `FinetuningModel`, optionally warm-starts its trainable variables
    from a checkpoint, logs every trainable variable (marking those restored
    from the checkpoint), and returns an `EstimatorSpec` for TRAIN or PREDICT.

    Reads enclosing-scope names that are not defined here: `config`, `tasks`,
    `num_train_steps` and `pretraining_config`.

    Args:
      features: input tensors, passed unchanged to `FinetuningModel`.
      labels: unused; required by the Estimator `model_fn` signature.
      mode: a `tf.estimator.ModeKeys` value; only TRAIN and PREDICT are
        supported (any other mode trips the assertion).
      params: unused; required by the Estimator `model_fn` signature.

    Returns:
      A `tf.estimator.EstimatorSpec`.
    """
    utils.log("Building model...")
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = FinetuningModel(config, tasks, is_training, features,
                            num_train_steps)

    # Load pre-trained weights from checkpoint. When a pre-training run is
    # configured, its latest checkpoint overrides `config.init_checkpoint`.
    init_checkpoint = config.init_checkpoint
    if pretraining_config is not None:
        init_checkpoint = tf.train.latest_checkpoint(
            pretraining_config.model_dir)
    # Log the resolved checkpoint exactly once, whichever source supplied it
    # (including when none was found).
    utils.log("Using checkpoint", init_checkpoint)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    if init_checkpoint:
        assignment_map, initialized_variable_names = (
            modeling.get_assignment_map_from_checkpoint(tvars,
                                                        init_checkpoint))
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    utils.log("**** Trainable Variables ****")
    for var in tvars:
        init_string = (", *INIT_FROM_CKPT*"
                       if var.name in initialized_variable_names else "")
        utils.logerr(" name = %s, shape = %s%s", var.name, var.shape,
                     init_string)

    # Build model for training or prediction
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(
            model.loss, config.learning_rate, num_train_steps,
            weight_decay_rate=config.weight_decay_rate,
            warmup_proportion=config.warmup_proportion,
            n_transformer_layers=model.bert_config.num_hidden_layers)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=model.loss,
            train_op=train_op,
            training_hooks=[
                training_utils.ETAHook(
                    {} if config.use_tpu else dict(loss=model.loss),
                    num_train_steps, config.iterations_per_loop,
                    config.use_tpu, 10)
            ])
    else:
        assert mode == tf.estimator.ModeKeys.PREDICT
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=utils.flatten_dict(model.outputs))
    utils.log("Building complete")
    return output_spec