def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]
  answer_type = features["answer_type"]
  supporting_fact = features["is_supporting_fact"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend padding tokens.
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)
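  # For illustration (assumed toy values): with padding_token_id == 0 and
  # token_ids = [[5, 6, 0]], `tf.not_equal` gives [[True, True, False]],
  # which after expand_dims + tile becomes shape [1, 3, 3]:
  #   [[[1, 1, 0],
  #     [1, 1, 0],
  #     [1, 1, 0]]]
  # i.e. no query position may attend to the padding position.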
  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  # The "pooler" converts the encoded sequence tensor of shape
  # [batch_size, seq_length, hidden_size] to a tensor of shape
  # [batch_size, hidden_size]. This is necessary for segment-level
  # (or segment-pair-level) classification tasks where we need a fixed
  # dimensional representation of the segment.
  with tf.variable_scope("pooler"):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token. We assume that this has been pre-trained.
    first_token_tensor = tf.squeeze(
        main_output.final_hidden_states[:, 0:1, :], axis=1)
    pooled_output = tf.layers.dense(
        first_token_tensor,
        model_config.hidden_size,
        activation=tf.tanh,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=model_config.initializer_range))

  yesno_logits = yesno_model(pooled_output)
  supporting_fact_logits = supporting_fact_model(pooled_output)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                 init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    host_inputs = dict()

    span_prediction_loss = losses.BatchSpanCrossEntropyLoss()
    total_loss = 0
    qa_loss = span_prediction_loss(
        logits=span_logits,
        annotation_begins=features["answer_annotation_begins"],
        annotation_ends=features["answer_annotation_ends"],
        annotation_labels=features["answer_annotation_labels"],
        block_ids=block_ids,
        num_replicas=num_replicas_concat,
        eps=1e-5)
    host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)
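    # NOTE (added for clarity): host_call tensors must be at least rank 1
    # (TPUEstimator concatenates them across training steps), hence scalar
    # metrics are wrapped with `tf.expand_dims(..., 0)` here and below.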
    total_loss += qa_loss

    # example_mask = tf.cast(tf.not_equal(block_ids, 0), tf.float32)
    # yesno_loss = compute_pooled_loss(yesno_logits, answer_type, 3,
    #                                  example_mask)
    # supporting_fact_loss = compute_supporting_facts_loss(
    #     supporting_fact_logits, supporting_fact, example_mask)

    hotpot_qa_loss = hotpot_qa_losses.BatchSpanCrossEntropyLoss()
    yesno_loss, supporting_fact_loss = hotpot_qa_loss(
        yesno_logits, answer_type, supporting_fact_logits, supporting_fact,
        block_ids, eps=1e-5)

    host_inputs["train_metrics/yesno_loss"] = tf.expand_dims(yesno_loss, 0)
    total_loss += yesno_loss

    host_inputs["train_metrics/supporting_fact_loss"] = tf.expand_dims(
        supporting_fact_loss, 0)
    total_loss += supporting_fact_loss

    # Add regularization losses.
    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    train_op = optimization.create_optimizer(total_loss, learning_rate,
                                             num_train_steps, num_warmup_steps,
                                             use_tpu, optimizer, poly_power,
                                             start_warmup_step,
                                             learning_rate_schedule,
                                             reduce_loss_sum=True)

    host_inputs.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "train_metrics/loss":
            tf.expand_dims(total_loss, 0),
    })

    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics")),
                 host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    begin_logits_values, begin_logits_indices = tf.math.top_k(
        span_logits[:, :, 0],
        k=nbest_logits_for_eval,
    )
    end_logits_values, end_logits_indices = tf.math.top_k(
        span_logits[:, :, 1],
        k=nbest_logits_for_eval,
    )

    predictions = {
        "block_ids": tf.identity(block_ids),
        "begin_logits_values": begin_logits_values,
        "begin_logits_indices": begin_logits_indices,
        "end_logits_values": end_logits_values,
        "end_logits_indices": end_logits_indices,
        "token_ids": tf.identity(token_ids),
        "answer_type": answer_type,
        "yesno_logits": yesno_logits,
        "supporting_fact_logits": supporting_fact_logits,
        "is_supporting_fact": supporting_fact,
    }

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
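# A minimal sketch (not part of the original source) of how a `model_fn` like
# the one above is typically handed to TPUEstimator. The flag names below are
# illustrative assumptions, not necessarily this repo's exact flags:
#
#   run_config = tf.estimator.tpu.RunConfig(
#       master=FLAGS.master,
#       model_dir=FLAGS.output_dir,
#       tpu_config=tf.estimator.tpu.TPUConfig(
#           iterations_per_loop=FLAGS.iterations_per_loop))
#   estimator = tf.estimator.tpu.TPUEstimator(
#       use_tpu=FLAGS.use_tpu,
#       model_fn=model_fn,  # the closure defined above
#       config=run_config,
#       train_batch_size=FLAGS.train_batch_size,
#       predict_batch_size=FLAGS.predict_batch_size)
#   estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)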
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend padding tokens.
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  is_summary_loss_enabled = (
      mode == tf.estimator.ModeKeys.TRAIN and summary_loss_weight is not None
      and summary_loss_weight > 0)
  if is_summary_loss_enabled:
    logging.info("Using summary prediction loss with weight %.3f",
                 summary_loss_weight)
    summary_token_ids = features["summary_token_ids"]
    # Shift the targets one position to the left so the decoder predicts
    # the next token at every step.
    summary_labels = tf.roll(summary_token_ids, shift=-1, axis=1)
    decoder = modeling.ReadItTwiceDecoderModel(
        config=model_config,
        num_layers_override=summary_num_layers,
        num_cross_attention_heads=summary_num_cross_attention_heads,
        enable_default_side_input=summary_enable_default_side_input,
        use_one_hot_embeddings=use_one_hot_embeddings)

    summary_token_logits = decoder(
        token_ids=summary_token_ids,
        side_input=main_output.global_summary.states,
        token2side_input_att_mask=modeling.get_cross_block_att(
            block_ids,
            block_pos,
            main_output.global_summary.block_ids,
            main_output.global_summary.block_pos,
            cross_block_attention_mode="doc"),
        training=True)

    language_model_loss_fn = losses.LanguageModelLoss(
        decoder.get_token_embedding_table(),
        hidden_size=model_config.hidden_size)
    language_model_loss = language_model_loss_fn(
        summary_token_logits,
        summary_labels,
        padding_token_id=padding_token_id).loss
  else:
    language_model_loss = None

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
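  # NOTE (added for clarity): on TPU the checkpoint restore is deferred into
  # `scaffold_fn`, which TPUEstimator invokes while building the final graph;
  # off TPU a direct `init_from_checkpoint` call at this point suffices.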
Variables ****") for var in tvars: init_string = "" if var.name in initialized_variable_names: init_string = ", *INIT_FROM_CKPT*" logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) output_spec = None if mode == tf.estimator.ModeKeys.TRAIN: host_inputs = dict() span_prediction_loss = losses.BatchSpanCrossEntropyLoss() qa_loss = span_prediction_loss( logits=span_logits, annotation_begins=features["answer_annotation_begins"], annotation_ends=features["answer_annotation_ends"], annotation_labels=features["answer_annotation_labels"], block_ids=block_ids, num_replicas=num_replicas_concat, eps=1e-5) host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0) if language_model_loss is not None: total_loss = ( 1.0 / (1.0 + summary_loss_weight) * qa_loss + summary_loss_weight / (1.0 + summary_loss_weight) * language_model_loss) host_inputs["train_metrics/summary_lm_loss"] = tf.expand_dims( language_model_loss, 0) else: total_loss = qa_loss # Add regularization losses. if model.losses: total_loss += tf.math.add_n(model.losses) train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, optimizer, poly_power, start_warmup_step, learning_rate_schedule, reduce_loss_sum=True) host_inputs.update({ "global_step": tf.expand_dims(tf.train.get_or_create_global_step(), 0), "train_metrics/loss": tf.expand_dims(total_loss, 0), }) host_call = (functools.partial(record_summary_host_fn, metrics_dir=os.path.join( FLAGS.output_dir, "train_metrics")), host_inputs) output_spec = tf.estimator.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn, host_call=host_call) elif mode == tf.estimator.ModeKeys.PREDICT: begin_logits_values, begin_logits_indices = tf.math.top_k( span_logits[:, :, 0], k=nbest_logits_for_eval, ) end_logits_values, end_logits_indices = tf.math.top_k( span_logits[:, :, 1], k=nbest_logits_for_eval, ) predictions = { "block_ids": tf.identity(block_ids), "begin_logits_values": begin_logits_values, "begin_logits_indices": begin_logits_indices, "end_logits_values": end_logits_values, "end_logits_indices": end_logits_indices, "token_ids": tf.identity(token_ids), } output_spec = tf.estimator.tpu.TPUEstimatorSpec( mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) else: raise ValueError("Only TRAIN and PREDICT modes is supported: %s" % (mode)) return output_spec