def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend over padding tokens.
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  is_summary_loss_enabled = (
      mode == tf.estimator.ModeKeys.TRAIN and
      summary_loss_weight is not None and summary_loss_weight > 0)
  if is_summary_loss_enabled:
    logging.info("Using summary prediction loss with weight %.3f",
                 summary_loss_weight)
    summary_token_ids = features["summary_token_ids"]
    # Next-token prediction targets: shift the sequence left by one
    # (`tf.roll` wraps the first token around to the last position).
    summary_labels = tf.roll(summary_token_ids, shift=-1, axis=1)
    decoder = modeling.ReadItTwiceDecoderModel(
        config=model_config,
        num_layers_override=summary_num_layers,
        num_cross_attention_heads=summary_num_cross_attention_heads,
        enable_default_side_input=summary_enable_default_side_input,
        use_one_hot_embeddings=use_one_hot_embeddings)
    summary_token_logits = decoder(
        token_ids=summary_token_ids,
        side_input=main_output.global_summary.states,
        token2side_input_att_mask=modeling.get_cross_block_att(
            block_ids,
            block_pos,
            main_output.global_summary.block_ids,
            main_output.global_summary.block_pos,
            cross_block_attention_mode="doc"),
        training=True)
    language_model_loss_fn = losses.LanguageModelLoss(
        decoder.get_token_embedding_table(),
        hidden_size=model_config.hidden_size)
    language_model_loss = language_model_loss_fn(
        summary_token_logits,
        summary_labels,
        padding_token_id=padding_token_id).loss
  else:
    language_model_loss = None

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                 init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    host_inputs = dict()

    span_prediction_loss = losses.BatchSpanCrossEntropyLoss()
    qa_loss = span_prediction_loss(
        logits=span_logits,
        annotation_begins=features["answer_annotation_begins"],
        annotation_ends=features["answer_annotation_ends"],
        annotation_labels=features["answer_annotation_labels"],
        block_ids=block_ids,
        num_replicas=num_replicas_concat,
        eps=1e-5)
    host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)

    if language_model_loss is not None:
      # Convex combination of the QA and summary LM losses:
      # 1/(1+w) * qa_loss + w/(1+w) * lm_loss, so the weights sum to 1.
      total_loss = (
          1.0 / (1.0 + summary_loss_weight) * qa_loss +
          summary_loss_weight / (1.0 + summary_loss_weight) *
          language_model_loss)
      host_inputs["train_metrics/summary_lm_loss"] = tf.expand_dims(
          language_model_loss, 0)
    else:
      total_loss = qa_loss

    # Add regularization losses.
    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    train_op = optimization.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        use_tpu,
        optimizer,
        poly_power,
        start_warmup_step,
        learning_rate_schedule,
        reduce_loss_sum=True)

    host_inputs.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "train_metrics/loss":
            tf.expand_dims(total_loss, 0),
    })
    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir,
                                 "train_metrics")), host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    begin_logits_values, begin_logits_indices = tf.math.top_k(
        span_logits[:, :, 0], k=nbest_logits_for_eval)
    end_logits_values, end_logits_indices = tf.math.top_k(
        span_logits[:, :, 1], k=nbest_logits_for_eval)

    predictions = {
        "block_ids": tf.identity(block_ids),
        "begin_logits_values": begin_logits_values,
        "begin_logits_indices": begin_logits_indices,
        "end_logits_values": end_logits_values,
        "end_logits_indices": end_logits_indices,
        "token_ids": tf.identity(token_ids),
    }
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
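
# Usage sketch (not part of the original code): how a `model_fn` closure like
# the ones in this file is handed to TPUEstimator. The helper name and its
# parameters are hypothetical; the `tf.estimator.tpu.*` calls are the
# standard TF 1.x API that these `model_fn`s already target.
def build_estimator_sketch(model_fn, output_dir, use_tpu, train_batch_size,
                           iterations_per_loop=1000, num_tpu_cores=8):
  """Wraps a `model_fn` in a TPUEstimator (CPU/GPU fallback if not on TPU)."""
  run_config = tf.estimator.tpu.RunConfig(
      model_dir=output_dir,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=num_tpu_cores))
  return tf.estimator.tpu.TPUEstimator(
      use_tpu=use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=train_batch_size,
      predict_batch_size=train_batch_size)


# Example (hypothetical `train_input_fn` and flag values):
# estimator = build_estimator_sketch(model_fn, FLAGS.output_dir,
#                                    use_tpu=True, train_batch_size=32)
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)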
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]
  answer_type = features["answer_type"]
  supporting_fact = features["is_supporting_fact"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend over padding tokens.
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  # The "pooler" converts the encoded sequence tensor of shape
  # [batch_size, seq_length, hidden_size] to a tensor of shape
  # [batch_size, hidden_size]. This is necessary for segment-level
  # (or segment-pair-level) classification tasks where we need a fixed
  # dimensional representation of the segment.
  with tf.variable_scope("pooler"):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token. We assume that this has been pre-trained.
    first_token_tensor = tf.squeeze(
        main_output.final_hidden_states[:, 0:1, :], axis=1)
    pooled_output = tf.layers.dense(
        first_token_tensor,
        model_config.hidden_size,
        activation=tf.tanh,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=model_config.initializer_range))

  yesno_logits = yesno_model(pooled_output)
  supporting_fact_logits = supporting_fact_model(pooled_output)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                 init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    host_inputs = dict()

    span_prediction_loss = losses.BatchSpanCrossEntropyLoss()
    total_loss = 0
    qa_loss = span_prediction_loss(
        logits=span_logits,
        annotation_begins=features["answer_annotation_begins"],
        annotation_ends=features["answer_annotation_ends"],
        annotation_labels=features["answer_annotation_labels"],
        block_ids=block_ids,
        num_replicas=num_replicas_concat,
        eps=1e-5)
    host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)
    total_loss += qa_loss

    # example_mask = tf.cast(tf.not_equal(block_ids, 0), tf.float32)
    # yesno_loss = compute_pooled_loss(yesno_logits, answer_type, 3,
    #                                  example_mask)
    # supporting_fact_loss = compute_supporting_facts_loss(
    #     supporting_fact_logits, supporting_fact, example_mask)
    hotpot_qa_loss = hotpot_qa_losses.BatchSpanCrossEntropyLoss()
    yesno_loss, supporting_fact_loss = hotpot_qa_loss(
        yesno_logits,
        answer_type,
        supporting_fact_logits,
        supporting_fact,
        block_ids,
        eps=1e-5)

    host_inputs["train_metrics/yesno_loss"] = tf.expand_dims(yesno_loss, 0)
    total_loss += yesno_loss

    host_inputs["train_metrics/supporting_fact_loss"] = tf.expand_dims(
        supporting_fact_loss, 0)
    total_loss += supporting_fact_loss

    # Add regularization losses.
    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    train_op = optimization.create_optimizer(
        total_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        use_tpu,
        optimizer,
        poly_power,
        start_warmup_step,
        learning_rate_schedule,
        reduce_loss_sum=True)

    host_inputs.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "train_metrics/loss":
            tf.expand_dims(total_loss, 0),
    })
    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir,
                                 "train_metrics")), host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    begin_logits_values, begin_logits_indices = tf.math.top_k(
        span_logits[:, :, 0], k=nbest_logits_for_eval)
    end_logits_values, end_logits_indices = tf.math.top_k(
        span_logits[:, :, 1], k=nbest_logits_for_eval)

    predictions = {
        "block_ids": tf.identity(block_ids),
        "begin_logits_values": begin_logits_values,
        "begin_logits_indices": begin_logits_indices,
        "end_logits_values": end_logits_values,
        "end_logits_indices": end_logits_indices,
        "token_ids": tf.identity(token_ids),
        "answer_type": answer_type,
        "yesno_logits": yesno_logits,
        "supporting_fact_logits": supporting_fact_logits,
        "is_supporting_fact": supporting_fact,
    }
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
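
# `yesno_model` and `supporting_fact_model` above are constructed outside
# this `model_fn`. A minimal sketch of plausible implementations (an
# assumption, not the original code): dense heads over the pooled [CLS]
# representation. The 3-way yes/no head matches the commented-out
# `compute_pooled_loss(yesno_logits, answer_type, 3, ...)` call; the
# supporting-fact head is assumed to emit 2 logits per example.
def make_pooled_heads_sketch(initializer_range=0.02):
  """Returns (yesno_head, supporting_fact_head) callables."""

  def yesno_head(pooled_output):
    # [batch_size, hidden_size] -> [batch_size, 3] answer-type logits.
    return tf.layers.dense(
        pooled_output,
        3,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=initializer_range),
        name="yesno_logits")

  def supporting_fact_head(pooled_output):
    # [batch_size, hidden_size] -> [batch_size, 2] supporting-fact logits.
    return tf.layers.dense(
        pooled_output,
        2,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=initializer_range),
        name="supporting_fact_logits")

  return yesno_head, supporting_fact_head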
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  batch_size = tf.shape(token_ids)[0]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]

  annotation_begins = features.get("annotation_begins")
  annotation_ends = features.get("annotation_ends")
  annotation_labels = features.get("annotation_labels")

  # Do not attend over padding tokens.
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  mlm_loss_fn = losses.LanguageModelLoss(
      model.get_token_embedding_table(),
      hidden_size=model_config.hidden_size,
      name="mlm_loss")
  mlm_loss_output = mlm_loss_fn(
      input_tensor=main_output.final_hidden_states,
      label_ids=features["masked_lm_ids"],
      positions=features["masked_lm_positions"],
      label_weights=features["masked_lm_weights"],
      mlm_is_entity_mask=features.get("mlm_is_entity_mask"),
      mlm_is_not_entity_mask=features.get("mlm_is_not_entity_mask"),
      padding_token_id=padding_token_id)
  mlm_loss = mlm_loss_output.loss

  loss_to_log = dict(mlm_loss=tf.expand_dims(mlm_loss, 0))

  # Normalize so that the MLM loss (implicit weight 1.0) and all extra
  # losses form a convex combination whose weights sum to 1.
  loss_weight_denominator = 1.0 + sum(extra_loss.values())
  total_loss = mlm_loss * (1.0 / loss_weight_denominator)

  for loss_name, loss_weight in extra_loss.items():
    logging.info("EXTRA LOSS: %s with weight %.2f", loss_name,
                 loss_weight / loss_weight_denominator)
    if model_config.summary_mode == "entity":
      # Entity label "1" corresponds to an unknown entity, so there is no
      # need to compute the coreference resolution loss for these unknown
      # entities.
      labels_weight = tf.cast(
          tf.logical_and(
              tf.not_equal(
                  tf.expand_dims(main_output.local_summary.labels, 1), 1),
              tf.not_equal(
                  tf.expand_dims(main_output.global_summary.labels, 0), 1)),
          tf.float32)
    else:
      labels_weight = None

    if loss_name == "sdp":
      loss_fn = losses.BatchCoreferenceResolutionLoss(
          apply_linear_layer=False)
      loss_value = loss_fn(
          main_output.local_summary.states,
          main_output.local_summary.labels,
          main_output.global_summary.states,
          main_output.global_summary.labels,
          labels_weight=labels_weight)
    elif loss_name == "sdp_linear":
      loss_fn = losses.BatchCoreferenceResolutionLoss(
          apply_linear_layer=True)
      loss_value = loss_fn(
          main_output.local_summary.states,
          main_output.local_summary.labels,
          main_output.global_summary.states,
          main_output.global_summary.labels,
          labels_weight=labels_weight)
    elif loss_name == "spp_linear":
      loss_fn = losses.BatchCoreferenceResolutionLoss(
          apply_linear_layer=True)
      # Positive examples are blocks which go one after another in the
      # original document.
      labels_mask = tf.less_equal(
          tf.abs(
              tf.expand_dims(main_output.local_summary.block_pos, 1) -
              tf.expand_dims(main_output.global_summary.block_pos, 0)), 1)
      loss_value = loss_fn(
          main_output.local_summary.states,
          main_output.local_summary.labels,
          main_output.global_summary.states,
          main_output.global_summary.labels,
          labels_mask=labels_mask,
          labels_weight=labels_weight)
    elif loss_name == "lm":
      # Next-token prediction targets: shift the sequence left by one
      # (`tf.roll` wraps the first token around to the last position).
      token_labels = tf.roll(token_ids, shift=-1, axis=1)
      # [batch_size, global_batch_size]
      token2side_input_att_mask = modeling.get_cross_block_att(
          block_ids,
          block_pos,
          main_output.global_summary.block_ids,
          main_output.global_summary.block_pos,
          cross_block_attention_mode=cross_block_attention_mode,
          cast_to_int32=False)
      # We want to exclude the summary of the block itself from the decoder
      # side input. As a proxy for this, we use block_ids AND block_pos.
      samples_are_the_same = tf.logical_and(
          tf.equal(
              tf.expand_dims(block_ids, 1),
              tf.expand_dims(main_output.global_summary.block_ids, 0)),
          tf.equal(
              tf.expand_dims(block_pos, 1),
              tf.expand_dims(main_output.global_summary.block_pos, 0)))
      token2side_input_att_mask = tf.stop_gradient(
          tf.cast(
              tf.logical_and(token2side_input_att_mask,
                             tf.logical_not(samples_are_the_same)),
              dtype=tf.int32))

      decoder = modeling.ReadItTwiceDecoderModel(
          config=model_config,
          num_layers_override=summary_num_layers,
          num_cross_attention_heads=summary_num_cross_attention_heads,
          enable_default_side_input=summary_enable_default_side_input,
          use_one_hot_embeddings=use_one_hot_embeddings)
      summary_token_logits = decoder(
          token_ids=token_ids,
          side_input=main_output.global_summary.states,
          token2side_input_att_mask=token2side_input_att_mask,
          training=True)

      language_model_loss_fn = losses.LanguageModelLoss(
          decoder.get_token_embedding_table(),
          hidden_size=model_config.hidden_size)

      # We don't penalize the first and last 32 tokens, so the model does
      # not have an incentive to memorize tokens at the borders of blocks.
      labels_weights = tf.concat([
          tf.zeros([batch_size, 32], dtype=tf.bool),
          tf.ones([batch_size, main_seq_length - 32 * 2], dtype=tf.bool),
          tf.zeros([batch_size, 32], dtype=tf.bool)
      ],
                                 axis=1)
      labels_weights = tf.logical_and(
          labels_weights, tf.not_equal(token_labels, padding_token_id))
      labels_weights = tf.stop_gradient(
          tf.cast(labels_weights, dtype=tf.float32))

      loss_value = language_model_loss_fn(
          summary_token_logits, token_labels,
          label_weights=labels_weights).loss
    else:
      raise ValueError("Unknown extra loss: {}".format(loss_name))

    loss_to_log[loss_name] = tf.expand_dims(loss_value, 0)
    total_loss += loss_value * (loss_weight / loss_weight_denominator)

  # Add regularization losses.
  if model.losses:
    total_loss += tf.math.add_n(model.losses)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                 init_string)

  metric_fn_tensors = dict(
      mlm_loss_per_sample=mlm_loss_output.mlm_loss_per_sample,
      mlm_accuracy_per_sample=mlm_loss_output.mlm_accuracy_per_sample,
      mlm_weight_per_sample=mlm_loss_output.mlm_weight_per_sample,
      mlm_loss_per_entity_sample=mlm_loss_output.mlm_loss_per_entity_sample,
      mlm_accuracy_per_entity_sample=(
          mlm_loss_output.mlm_accuracy_per_entity_sample),
      mlm_weight_per_entity_sample=(
          mlm_loss_output.mlm_weight_per_entity_sample),
      mlm_loss_per_non_entity_sample=(
          mlm_loss_output.mlm_loss_per_non_entity_sample),
      mlm_accuracy_per_non_entity_sample=(
          mlm_loss_output.mlm_accuracy_per_non_entity_sample),
      mlm_weight_per_non_entity_sample=(
          mlm_loss_output.mlm_weight_per_non_entity_sample),
      block_ids=block_ids)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps,
        use_tpu, optimizer, poly_power, start_warmup_step,
        learning_rate_schedule)

    metric_fn_tensors.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "loss":
            tf.expand_dims(total_loss, 0),
    })
    metric_fn_tensors.update(loss_to_log)

    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics"),
        metrics_name=metrics_name or "train_metrics"), metric_fn_tensors)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics = (functools.partial(
        metric_utils.masked_lm_metrics,
        is_train=False,
        metrics_name=metrics_name or "eval_metrics"), metric_fn_tensors)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and EVAL modes are supported: %s" % mode)

  return output_spec
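
# Illustration (not original code) of the loss weighting in the pretraining
# `model_fn` above: the MLM loss gets implicit weight 1.0 and everything is
# normalized by 1 + sum(weights), yielding a convex combination. E.g. with
# extra_loss = {"sdp": 0.5, "lm": 0.5} the denominator is 2.0 and the total
# is 0.5 * mlm + 0.25 * sdp + 0.25 * lm. The helper below is a hypothetical
# standalone restatement of that arithmetic.
def mix_losses_sketch(mlm_loss, weighted_extra_losses):
  """weighted_extra_losses: dict mapping loss name -> (weight, loss value)."""
  denominator = 1.0 + sum(w for w, _ in weighted_extra_losses.values())
  total = mlm_loss * (1.0 / denominator)
  for weight, value in weighted_extra_losses.values():
    total += value * (weight / denominator)
  return total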