def model_fn(features, labels, mode, params):
  """Build the fine-tuning graph and wrap it in a `TPUEstimatorSpec`.

  Closure over `config`, `tasks`, `num_train_steps` and `pretraining_config`
  from the enclosing scope. `labels` and `params` are part of the Estimator
  signature but are not used. Only TRAIN and PREDICT modes are handled; any
  other mode trips the assertion in the prediction branch.
  """
  utils.log("Building model...")
  training = mode == tf.estimator.ModeKeys.TRAIN
  model = FinetuningModel(config, tasks, training, features, num_train_steps)

  # Pick the warm-start checkpoint: the newest checkpoint of a pretraining run
  # (when one is configured) takes precedence over config.init_checkpoint.
  if pretraining_config is None:
    ckpt = config.init_checkpoint
  else:
    ckpt = tf.train.latest_checkpoint(pretraining_config.model_dir)
    utils.log("Using checkpoint", ckpt)

  trainable = tf.trainable_variables()
  scaffold_fn = None
  if ckpt:
    var_map, _ = modeling.get_assignment_map_from_checkpoint(trainable, ckpt)
    if not config.use_tpu:
      tf.train.init_from_checkpoint(ckpt, var_map)
    else:
      # On TPU the initialization is wrapped in a scaffold_fn that is handed
      # to the TPUEstimatorSpec below instead of running immediately.
      def _tpu_scaffold():
        tf.train.init_from_checkpoint(ckpt, var_map)
        return tf.train.Scaffold()

      scaffold_fn = _tpu_scaffold

  if mode != tf.estimator.ModeKeys.TRAIN:
    assert mode == tf.estimator.ModeKeys.PREDICT
    spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=utils.flatten_dict(model.outputs),
        scaffold_fn=scaffold_fn)
  else:
    train_op = optimization.create_optimizer(
        model.loss, config.learning_rate, num_train_steps,
        weight_decay_rate=config.weight_decay_rate,
        use_tpu=config.use_tpu,
        warmup_proportion=config.warmup_proportion,
        layerwise_lr_decay_power=config.layerwise_lr_decay,
        n_transformer_layers=model.bert_config.num_hidden_layers)
    # Tensors are not logged from the hook on TPU.
    hook_tensors = {} if config.use_tpu else dict(loss=model.loss)
    eta_hook = training_utils.ETAHook(
        hook_tensors, num_train_steps, config.iterations_per_loop,
        config.use_tpu, 10)
    spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=model.loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        training_hooks=[eta_hook])

  utils.log("Building complete")
  return spec
def model_fn(features, labels, mode, params):
  """Build the fine-tuning graph and wrap it in a plain `EstimatorSpec`.

  Closure over `config`, `tasks`, `num_train_steps` and `pretraining_config`
  from the enclosing scope. `labels` and `params` are part of the Estimator
  signature but are not used. Checkpoint weights are loaded directly with
  `tf.train.init_from_checkpoint` (no TPU scaffold). Only TRAIN and PREDICT
  modes are handled; any other mode trips the assertion below.
  """
  utils.log("Building model...")
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = FinetuningModel(config, tasks, is_training, features, num_train_steps)

  # Load pre-trained weights from checkpoint. A pretraining run's newest
  # checkpoint overrides config.init_checkpoint.
  init_checkpoint = config.init_checkpoint
  if pretraining_config is not None:
    init_checkpoint = tf.train.latest_checkpoint(pretraining_config.model_dir)
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  if init_checkpoint:
    # Logged once, here, regardless of where the checkpoint path came from.
    utils.log("Using checkpoint", init_checkpoint)
    assignment_map, initialized_variable_names = (
        modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint))
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  utils.log("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    # Format eagerly: the utils logging helpers here are called print-style
    # (positional args joined), so %-placeholders would not be substituted.
    utils.logerr(" name = %s, shape = %s%s"
                 % (var.name, var.shape, init_string))

  # Build model for training or prediction
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        model.loss, config.learning_rate, num_train_steps,
        weight_decay_rate=config.weight_decay_rate,
        warmup_proportion=config.warmup_proportion,
        n_transformer_layers=model.bert_config.num_hidden_layers)
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=model.loss,
        train_op=train_op,
        training_hooks=[
            training_utils.ETAHook(
                {} if config.use_tpu else dict(loss=model.loss),
                num_train_steps, config.iterations_per_loop,
                config.use_tpu, 10)
        ])
  else:
    assert mode == tf.estimator.ModeKeys.PREDICT
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=utils.flatten_dict(model.outputs))

  utils.log("Building complete")
  return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator (feature extraction).

  Runs a BERT encoder in inference mode and exposes the encoder layers
  selected by `layer_indexes` as predictions, keyed "layer_output_<i>",
  together with the example's "unique_id".

  Closure over `bert_config`, `init_checkpoint`, `layer_indexes`, `use_tpu`
  and `use_one_hot_embeddings` from the enclosing scope. Weights are always
  loaded from `init_checkpoint`.

  Raises:
    ValueError: if `mode` is not PREDICT.
  """
  # Fail fast, before any graph is built: only feature extraction is supported.
  if mode != tf.estimator.ModeKeys.PREDICT:
    raise ValueError("Only PREDICT modes are supported: %s" % (mode))

  unique_ids = features["unique_ids"]
  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  input_type_ids = features["input_type_ids"]

  model = modeling.BertModel(
      config=bert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  tvars = tf.trainable_variables()
  scaffold_fn = None
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
  if use_tpu:
    # On TPU, initialization is wrapped in the scaffold_fn given to the spec.
    def tpu_scaffold():
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
      return tf.train.Scaffold()

    scaffold_fn = tpu_scaffold
  else:
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  all_layers = model.get_all_encoder_layers()

  predictions = {
      "unique_id": unique_ids,
  }
  for (i, layer_index) in enumerate(layer_indexes):
    predictions["layer_output_%d" % i] = all_layers[layer_index]

  output_spec = tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator.

  Builds the BERT pretraining graph (masked-LM head plus next-sentence head)
  and returns a `TPUEstimatorSpec` for TRAIN or EVAL.

  Closure over `bert_config`, `init_checkpoint`, `learning_rate`,
  `num_train_steps`, `num_warmup_steps`, `use_tpu` and
  `use_one_hot_embeddings` from the enclosing scope.

  Raises:
    ValueError: if `mode` is neither TRAIN nor EVAL.
  """
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  next_sentence_labels = features["next_sentence_labels"]

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  # Masked-LM head: scores the masked positions against the (shared) input
  # embedding table.
  (masked_lm_loss,
   masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
       bert_config, model.get_sequence_output(), model.get_embedding_table(),
       masked_lm_positions, masked_lm_ids, masked_lm_weights)

  # Next-sentence head on the pooled output.
  (next_sentence_loss, next_sentence_example_loss,
   next_sentence_log_probs) = get_next_sentence_output(
       bert_config, model.get_pooled_output(), next_sentence_labels)

  # Training objective is the unweighted sum of the two task losses.
  total_loss = masked_lm_loss + next_sentence_loss

  tvars = tf.trainable_variables()

  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:
      # On TPU, initialization is wrapped in a scaffold_fn that is passed to
      # the TPUEstimatorSpec instead of being run here directly.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_log_probs, next_sentence_labels):
      """Computes the loss and accuracy of the model."""
      # Flatten to one row per predicted position so that argmax, labels and
      # weights line up element-wise.
      masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                       [-1, masked_lm_log_probs.shape[-1]])
      masked_lm_predictions = tf.argmax(
          masked_lm_log_probs, axis=-1, output_type=tf.int32)
      masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
      masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
      # masked_lm_weights excludes padded prediction slots from both metrics.
      masked_lm_accuracy = tf.metrics.accuracy(
          labels=masked_lm_ids,
          predictions=masked_lm_predictions,
          weights=masked_lm_weights)
      masked_lm_mean_loss = tf.metrics.mean(
          values=masked_lm_example_loss, weights=masked_lm_weights)

      next_sentence_log_probs = tf.reshape(
          next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
      next_sentence_predictions = tf.argmax(
          next_sentence_log_probs, axis=-1, output_type=tf.int32)
      next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
      next_sentence_accuracy = tf.metrics.accuracy(
          labels=next_sentence_labels, predictions=next_sentence_predictions)
      next_sentence_mean_loss = tf.metrics.mean(
          values=next_sentence_example_loss)

      return {
          "masked_lm_accuracy": masked_lm_accuracy,
          "masked_lm_loss": masked_lm_mean_loss,
          "next_sentence_accuracy": next_sentence_accuracy,
          "next_sentence_loss": next_sentence_mean_loss,
      }

    # TPUEstimator calls metric_fn itself with this tensor list.
    eval_metrics = (metric_fn, [
        masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
        masked_lm_weights, next_sentence_example_loss,
        next_sentence_log_probs, next_sentence_labels
    ])
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

  return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """Build the classifier graph and return a `tf.estimator.EstimatorSpec`.

  Closure over `self`, `bert_config`, `init_checkpoint`, `learning_rate`,
  `num_train_steps`, `num_warmup_steps`, `num_labels` and
  `use_one_hot_embeddings` from the enclosing scope.

  TRAIN returns a loss/train_op spec; EVAL returns accuracy, AUC and mean
  loss; any other mode returns the class probabilities as predictions.
  """
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  label_ids = features["label_ids"]

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  (total_loss, per_example_loss, logits, probabilities) = self.create_model(
      bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
      num_labels, use_one_hot_embeddings)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}

  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  # The public tf.estimator.EstimatorSpec is used instead of the private
  # tensorflow.python.estimator path (same class, stable import location).
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, False)
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(per_example_loss, label_ids, logits):
      predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
      accuracy = tf.metrics.accuracy(label_ids, predictions)
      # AUC needs a continuous score rather than hard argmax labels, otherwise
      # the ROC curve collapses to a single point. Use the softmax probability
      # of class 1; tf.metrics.auc casts labels to bool, so this metric is
      # only meaningful for binary classification.
      positive_scores = tf.nn.softmax(logits, axis=-1)[:, 1]
      auc = tf.metrics.auc(label_ids, positive_scores)
      loss = tf.metrics.mean(per_example_loss)
      return {
          "eval_accuracy": accuracy,
          "eval_auc": auc,
          "eval_loss": loss,
      }

    eval_metrics = metric_fn(per_example_loss, label_ids, logits)
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metric_ops=eval_metrics)
  else:
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode, predictions=probabilities)
  return output_spec
def model_fn(features, labels, mode, params):
  """Build the model for training.

  Chooses the model class from `config.masking_strategy`:
  adversarial (ADVERSARIAL_STRATEGY / MIX_ADV_STRATEGY), ratio-based
  (RW_STRATEGY, ratios read from `config.ratio_file`), or the plain
  `PretrainingModel`. Returns an `EstimatorSpec` for TRAIN or EVAL.

  Closure over `config` from the enclosing scope; `labels` and `params` are
  unused.

  Raises:
    ValueError: if `mode` is neither TRAIN nor EVAL.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  if config.masking_strategy in (pretrain_helpers.ADVERSARIAL_STRATEGY,
                                 pretrain_helpers.MIX_ADV_STRATEGY):
    model = AdversarialPretrainingModel(config, features, is_training)
  elif config.masking_strategy == pretrain_helpers.RW_STRATEGY:
    # Each non-blank line of the ratio file is whitespace-separated; only the
    # second column (parsed as float) is used.
    ratio = []
    with open(config.ratio_file, "r") as fin:
      for line in fin:
        line = line.strip()
        if line:
          tok = line.split()
          ratio.append(float(tok[1]))
    model = RatioBasedPretrainingModel(config, features, ratio, is_training)
  else:
    model = PretrainingModel(config, features, is_training)
  utils.log("Model is built!")

  tvars = tf.trainable_variables()

  initialized_variable_names = {}
  if config.init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(
        tvars, config.init_checkpoint)
    tf.train.init_from_checkpoint(config.init_checkpoint, assignment_map)

  utils.log("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    # utils.log joins its positional args (see the calls above), so the
    # %-placeholders must be substituted here rather than passed as args.
    utils.log(" name = %s, shape = %s%s" % (var.name, var.shape, init_string))

  if mode == tf.estimator.ModeKeys.TRAIN:
    # NOTE(review): only ADVERSARIAL_STRATEGY trains the teacher; a model built
    # for MIX_ADV_STRATEGY takes the single-optimizer branch — confirm intent.
    if config.masking_strategy == pretrain_helpers.ADVERSARIAL_STRATEGY:
      student_train_op = optimization.create_optimizer(
          model.mlm_loss, config.learning_rate, config.num_train_steps,
          weight_decay_rate=config.weight_decay_rate,
          use_tpu=config.use_tpu,
          warmup_steps=config.num_warmup_steps,
          lr_decay_power=config.lr_decay_power)
      teacher_train_op = optimization.create_optimizer(
          model.teacher_loss, config.teacher_learning_rate,
          config.num_train_steps,
          lr_decay_power=config.lr_decay_power)
      # One training step updates both the student (MLM) and the teacher.
      train_op = tf.group(student_train_op, teacher_train_op)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=model.total_loss,
          train_op=train_op,
          training_hooks=[
              training_utils.ETAHook(
                  dict(loss=model.mlm_loss,
                       teacher_loss=model.teacher_loss,
                       reward=model._baseline),
                  config.num_train_steps, config.iterations_per_loop,
                  config.use_tpu)
          ])
    else:
      train_op = optimization.create_optimizer(
          model.total_loss, config.learning_rate, config.num_train_steps,
          weight_decay_rate=config.weight_decay_rate,
          use_tpu=config.use_tpu,
          warmup_steps=config.num_warmup_steps,
          lr_decay_power=config.lr_decay_power)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=model.total_loss,
          train_op=train_op,
          training_hooks=[
              training_utils.ETAHook(
                  dict(loss=model.total_loss),
                  config.num_train_steps, config.iterations_per_loop,
                  config.use_tpu)
          ])
  elif mode == tf.estimator.ModeKeys.EVAL:
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=model.total_loss,
        eval_metric_ops=model.eval_metrics,
        evaluation_hooks=[
            training_utils.ETAHook(
                dict(loss=model.total_loss),
                config.num_eval_steps, config.iterations_per_loop,
                config.use_tpu, is_training=False)
        ])
  else:
    raise ValueError("Only TRAIN and EVAL modes are supported")
  return output_spec
def __init__(self, config: PretrainingConfig, features, is_training,
             init_checkpoint):
  """Build the generator / discriminator pretraining graph.

  Args:
    config: pretraining configuration (debug size, masking, generator and
      discriminator options, loss weights, fp16 flag).
    features: raw input features, converted with
      `pretrain_data.features_to_inputs` before masking.
    is_training: forwarded to every `_build_transformer` call.
    init_checkpoint: optional checkpoint path. When set, trainable variables
      are initialized from it, but only on Horovod rank 0.

  Sets `self.mlm_output`, `self.disc_output` (None unless
  `config.electra_objective`), `self.total_loss` and `self.eval_metrics`,
  a `(metric_fn, tensor_list)` pair.
  """
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  if config.debug:
    # Shrink the transformer for quick debug runs.
    self._bert_config.num_hidden_layers = 3
    self._bert_config.hidden_size = 144
    self._bert_config.intermediate_size = 144 * 4
    self._bert_config.num_attention_heads = 4

  # Variables are created through a custom getter chosen from use_fp16.
  compute_type = modeling.infer_dtype(config.use_fp16)
  custom_getter = modeling.get_custom_getter(compute_type)

  with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
    # Mask the input
    masked_inputs = pretrain_helpers.mask(
        config, pretrain_data.features_to_inputs(features), config.mask_prob)

    # Generator
    embedding_size = self._bert_config.hidden_size if config.embedding_size is None else config.embedding_size
    if config.uniform_generator:
      mlm_output = self._get_masked_lm_output(masked_inputs, None)
    elif config.electra_objective and config.untied_generator:
      # Separate "generator" transformer with its own (smaller) config.
      generator = self._build_transformer(
          name="generator",
          inputs=masked_inputs,
          is_training=is_training,
          use_fp16=config.use_fp16,
          bert_config=get_generator_config(config, self._bert_config),
          embedding_size=None if config.untied_generator_embeddings else embedding_size,
          untied_embeddings=config.untied_generator_embeddings)
      mlm_output = self._get_masked_lm_output(
          masked_inputs, generator)
    else:
      # Built under the same name ("electra") as the discriminator below;
      # presumably _build_transformer reuses those variables — confirm.
      generator = self._build_transformer(
          name="electra",
          inputs=masked_inputs,
          is_training=is_training,
          use_fp16=config.use_fp16,
          embedding_size=embedding_size)
      mlm_output = self._get_masked_lm_output(
          masked_inputs, generator)
    # Replace masked tokens with generator samples to get the
    # discriminator's input.
    fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
    self.mlm_output = mlm_output
    self.total_loss = config.gen_weight * mlm_output.loss

    utils.log("Generator is built!")

    # Discriminator
    self.disc_output = None
    if config.electra_objective:
      discriminator = self._build_transformer(
          name="electra",
          inputs=fake_data.inputs,
          is_training=is_training,
          use_fp16=config.use_fp16,
          embedding_size=embedding_size)
      utils.log("Discriminator is built!")
      self.disc_output = self._get_discriminator_output(
          inputs=fake_data.inputs,
          discriminator=discriminator,
          labels=fake_data.is_fake_tokens)
      self.total_loss += config.disc_weight * self.disc_output.loss

  # NOTE(review): weights are loaded on Horovod rank 0 only; other ranks
  # presumably receive them via a broadcast hook elsewhere — confirm.
  if init_checkpoint and hvd.rank() == 0:
    print("Loading checkpoint", init_checkpoint)
    assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
        tvars=tf.trainable_variables(), init_checkpoint=init_checkpoint, prefix="")
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  # Evaluation
  # Tensors handed to metric_fn; the dict order fixes the positional order
  # of metric_fn's *args.
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  if config.electra_objective:
    eval_fn_inputs.update({
        "disc_loss": self.disc_output.per_example_loss,
        "disc_labels": self.disc_output.labels,
        "disc_probs": self.disc_output.probs,
        "disc_preds": self.disc_output.preds,
        "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1,
                                    output_type=tf.int32)
    })
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  def metric_fn(*args):
    """Computes the loss and accuracy of the model."""
    # Re-associate positional tensors with their names.
    d = dict(zip(eval_fn_keys, args))
    metrics = dict()
    metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
        labels=tf.reshape(d["masked_lm_ids"], [-1]),
        predictions=tf.reshape(d["masked_lm_preds"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    metrics["masked_lm_loss"] = tf.metrics.mean(
        values=tf.reshape(d["mlm_loss"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    if config.electra_objective:
      # Accuracy of the generator's sampled tokens at masked positions.
      metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
          labels=tf.reshape(d["masked_lm_ids"], [-1]),
          predictions=tf.reshape(d["sampled_tokids"], [-1]),
          weights=tf.reshape(d["masked_lm_weights"], [-1]))
      if config.disc_weight > 0:
        metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
        metrics["disc_auc"] = tf.metrics.auc(
            d["disc_labels"] * d["input_mask"],
            d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
        metrics["disc_accuracy"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["input_mask"])
        # Precision/recall are computed as accuracy restricted (via weights)
        # to predicted-positive / actually-positive tokens respectively.
        metrics["disc_precision"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_preds"] * d["input_mask"])
        metrics["disc_recall"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_labels"] * d["input_mask"])
    return metrics
  self.eval_metrics = (metric_fn, eval_fn_values)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator.

  Builds the classifier via `processor.create_model` and returns a
  `TPUEstimatorSpec` for TRAIN, EVAL (metrics from
  `processor.eval_metric_fn`) or PREDICT (probabilities plus the raw
  input_ids / label_ids / input_mask).

  Closure over `processor`, `bert_config`, `init_checkpoint`,
  `learning_rate`, `num_train_steps`, `num_warmup_steps`, `num_labels`,
  `use_tpu` and `use_one_hot_embeddings` from the enclosing scope.
  """
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  label_ids = features["label_ids"]
  if "is_real_example" in features:
    is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
  else:
    # One weight of 1.0 per example. Use the dynamic batch size: the static
    # label_ids.shape[0] is None when the batch dimension is not fixed, which
    # made tf.ones fail.
    is_real_example = tf.ones(tf.shape(label_ids)[:1], dtype=tf.float32)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  (total_loss, per_example_loss, logits,
   probabilities) = processor.create_model(
       bert_config, is_training, input_ids, input_mask, segment_ids,
       label_ids, num_labels, use_one_hot_embeddings)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:
      # On TPU, initialization is wrapped in the scaffold_fn given to the spec.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics = (processor.eval_metric_fn,
                    [per_example_loss, label_ids, logits, input_mask,
                     is_real_example])
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={"probabilities": probabilities,
                     'input_ids': input_ids,
                     'label_ids': label_ids,
                     'input_mask': input_mask},
        scaffold_fn=scaffold_fn)
  return output_spec
def __init__(self, config: configure_pretraining.PretrainingConfig,
             features, is_training):
  """Build the generator / discriminator pretraining graph.

  Args:
    config: pretraining configuration (debug size, masking, generator and
      discriminator options, loss weights, checkpoint, TPU flag).
    features: raw input features, converted with
      `pretrain_data.features_to_inputs`, passed through
      `pretrain_helpers.unmask`, then re-masked.
    is_training: forwarded to every `_build_transformer` call.

  Sets `self.mlm_output`, `self.total_loss`, `self.scaffold_fn` (non-None
  only on TPU with `config.init_checkpoint`) and `self.eval_metrics`, a
  `(metric_fn, tensor_list)` pair for TPUEstimatorSpec.
  """
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  if config.debug:
    # Shrink the transformer for quick debug runs.
    self._bert_config.num_hidden_layers = 3
    self._bert_config.hidden_size = 144
    self._bert_config.intermediate_size = 144 * 4
    self._bert_config.num_attention_heads = 4

  # Mask the input. `unmask` is applied to the raw inputs first — presumably
  # restoring any pre-masked positions so masking starts clean; confirm.
  masked_inputs = pretrain_helpers.mask(
      config,
      pretrain_helpers.unmask(pretrain_data.features_to_inputs(features)),
      config.mask_prob)

  # Generator
  embedding_size = (
      self._bert_config.hidden_size if config.embedding_size is None else
      config.embedding_size)
  if config.uniform_generator:
    mlm_output = self._get_masked_lm_output(masked_inputs, None)
  elif config.electra_objective and config.untied_generator:
    generator = self._build_transformer(
        masked_inputs, is_training,
        bert_config=get_generator_config(config, self._bert_config),
        embedding_size=(None if config.untied_generator_embeddings
                        else embedding_size),
        untied_embeddings=config.untied_generator_embeddings,
        name="generator")
    mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  else:
    generator = self._build_transformer(
        masked_inputs, is_training, embedding_size=embedding_size)
    mlm_output = self._get_masked_lm_output(masked_inputs, generator)
  fake_data = self._get_fake_data(masked_inputs, mlm_output.logits)
  self.mlm_output = mlm_output
  self.total_loss = config.gen_weight * mlm_output.loss

  # Discriminator. With a tied generator the discriminator reuses its
  # variables (reuse=True).
  disc_output = None
  if config.electra_objective:
    discriminator = self._build_transformer(
        fake_data.inputs, is_training, reuse=not config.untied_generator,
        embedding_size=embedding_size)
    disc_output = self._get_discriminator_output(
        fake_data.inputs, discriminator, fake_data.is_fake_tokens)
    self.total_loss += config.disc_weight * disc_output.loss

  # Warm-start from config.init_checkpoint. On TPU the initialization is
  # wrapped in self.scaffold_fn for the caller to pass to TPUEstimatorSpec.
  tvars = tf.trainable_variables()
  self.scaffold_fn = None
  if config.init_checkpoint:
    assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
        tvars, config.init_checkpoint)
    if config.use_tpu:
      def tpu_scaffold():
        tf.train.init_from_checkpoint(config.init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      self.scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(config.init_checkpoint, assignment_map)

  # Evaluation
  # Tensors handed to metric_fn; the dict order fixes the positional order
  # of metric_fn's *args.
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  if config.electra_objective:
    eval_fn_inputs.update({
        "disc_loss": disc_output.per_example_loss,
        "disc_labels": disc_output.labels,
        "disc_probs": disc_output.probs,
        "disc_preds": disc_output.preds,
        "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1,
                                    output_type=tf.int32)
    })
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  def metric_fn(*args):
    """Computes the loss and accuracy of the model.

    Returns:
      A dict mapping metric names to (value, update_op) pairs, as required
      of an eval metric_fn.
    """
    d = dict(zip(eval_fn_keys, args))
    metrics = dict()
    metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
        labels=tf.reshape(d["masked_lm_ids"], [-1]),
        predictions=tf.reshape(d["masked_lm_preds"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    metrics["masked_lm_loss"] = tf.metrics.mean(
        values=tf.reshape(d["mlm_loss"], [-1]),
        weights=tf.reshape(d["masked_lm_weights"], [-1]))
    if config.electra_objective:
      metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy(
          labels=tf.reshape(d["masked_lm_ids"], [-1]),
          predictions=tf.reshape(d["sampled_tokids"], [-1]),
          weights=tf.reshape(d["masked_lm_weights"], [-1]))
      if config.disc_weight > 0:
        metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"])
        metrics["disc_auc"] = tf.metrics.auc(
            d["disc_labels"] * d["input_mask"],
            d["disc_probs"] * tf.cast(d["input_mask"], tf.float32))
        metrics["disc_accuracy"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["input_mask"])
        # Precision/recall as accuracy restricted (via weights) to
        # predicted-positive / actually-positive tokens respectively.
        metrics["disc_precision"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_preds"] * d["input_mask"])
        metrics["disc_recall"] = tf.metrics.accuracy(
            labels=d["disc_labels"], predictions=d["disc_preds"],
            weights=d["disc_labels"] * d["input_mask"])
    # Return only the metrics dict. The previous `return metrics, scaffold_fn`
    # referenced an undefined name (NameError) and produced a tuple.
    return metrics

  self.eval_metrics = (metric_fn, eval_fn_values)
def model_fn(features, labels, mode, params):
  """Build the model for training.

  Builds `PretrainingModel`, optionally warm-starts it from the newest
  checkpoint under the `config.init_checkpoint` directory, and returns a
  `TPUEstimatorSpec` for TRAIN or EVAL.

  Closure over `config` from the enclosing scope; `labels` and `params` are
  unused.

  Raises:
    ValueError: if `mode` is neither TRAIN nor EVAL.
  """
  model = PretrainingModel(config, features,
                           mode == tf.estimator.ModeKeys.TRAIN)
  utils.log("Model is built!")

  # Load pre-trained weights from checkpoint. tf.train.latest_checkpoint
  # cannot take None, so only resolve it when a directory is configured.
  init_checkpoint = None
  if config.init_checkpoint:
    init_checkpoint = tf.train.latest_checkpoint(config.init_checkpoint)
    utils.log("Using checkpoint", init_checkpoint)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    assignment_map, initialized_variable_names = (
        modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint))
    if config.use_tpu:
      # On TPU, initialization is wrapped in the scaffold_fn given to the spec.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  utils.log("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    # utils.log joins its positional args (see the calls above), so the
    # %-placeholders must be substituted here rather than passed as args.
    utils.log(" name = %s, shape = %s%s" % (var.name, var.shape, init_string))

  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        model.total_loss, config.learning_rate, config.num_train_steps,
        weight_decay_rate=config.weight_decay_rate,
        use_tpu=config.use_tpu,
        warmup_steps=config.num_warmup_steps,
        lr_decay_power=config.lr_decay_power)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=model.total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        training_hooks=[training_utils.ETAHook(
            {} if config.use_tpu else dict(loss=model.total_loss),
            config.num_train_steps, config.iterations_per_loop,
            config.use_tpu)])
  elif mode == tf.estimator.ModeKeys.EVAL:
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=model.total_loss,
        scaffold_fn=scaffold_fn,
        eval_metrics=model.eval_metrics,
        evaluation_hooks=[training_utils.ETAHook(
            {} if config.use_tpu else dict(loss=model.total_loss),
            config.num_eval_steps, config.iterations_per_loop,
            config.use_tpu, is_training=False)])
  else:
    raise ValueError("Only TRAIN and EVAL modes are supported")
  return output_spec