def test_cross_entropy_loss_unique_block_ids(self, batch_size, seq_length,
                                             num_annotations):
  np.random.seed(31415)
  logits = np.random.random((batch_size, seq_length, 2))
  logits = (logits - 0.5) * 100
  logits = logits.astype(np.float32)
  annotation_begins = np.stack([
      np.random.choice(seq_length, size=num_annotations, replace=False)
      for _ in range(batch_size)
  ])
  annotation_ends = np.stack([
      np.random.choice(seq_length, size=num_annotations, replace=False)
      for _ in range(batch_size)
  ])
  one_hot_labels = np.zeros((batch_size, seq_length, 2), dtype=np.float32)
  for i in range(batch_size):
    one_hot_labels[i, annotation_begins[i], 0] = 1
    one_hot_labels[i, annotation_ends[i], 1] = 1

  logits_tf = tf.compat.v1.placeholder_with_default(
      logits, shape=[None, None, 2])
  block_ids = tf.range(batch_size)
  annotation_begins_tf = tf.compat.v1.placeholder_with_default(
      annotation_begins, shape=[None, None])
  annotation_ends_tf = tf.compat.v1.placeholder_with_default(
      annotation_ends, shape=[None, None])
  annotation_labels = tf.ones((batch_size, num_annotations), dtype=tf.float32)
  one_hot_labels_tf = tf.compat.v1.placeholder_with_default(
      one_hot_labels, shape=[None, None, 2])

  loss_layer = losses.BatchSpanCrossEntropyLoss()

  init_op = tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer())
  self.evaluate(init_op)

  actual_loss = loss_layer(logits_tf, annotation_begins_tf,
                           annotation_ends_tf, annotation_labels, block_ids)

  logits_masked = logits - tf.cast(one_hot_labels_tf < 0.5, tf.float32) * 1e6
  or_cross_entropy = (
      tf.math.reduce_logsumexp(logits_tf, axis=-2) -
      tf.math.reduce_logsumexp(logits_masked, axis=-2))
  expected_loss = tf.math.reduce_sum(or_cross_entropy)

  actual_loss_value, expected_loss_value = self.evaluate(
      [actual_loss, expected_loss])
  self.assertNear(actual_loss_value, expected_loss_value, err=1e-4)
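# A minimal numpy sketch (not part of the test above; the helper name is
# illustrative) of the identity the expected-loss computation relies on: with
# unique block ids and all annotation labels positive, the per-example span
# loss is an "OR" cross-entropy over the annotated positions,
#   -log(sum_{i in A} softmax(x)_i) = logsumexp(x) - logsumexp(x[A]),
# which is exactly reduce_logsumexp(logits) - reduce_logsumexp(masked logits).
import numpy as np
import scipy.special


def or_cross_entropy_reference(logits_1d, annotated_positions):
  """Negative log of the total probability mass on the annotated positions."""
  probs = scipy.special.softmax(logits_1d)
  direct = -np.log(probs[annotated_positions].sum())
  via_logsumexp = (scipy.special.logsumexp(logits_1d) -
                   scipy.special.logsumexp(logits_1d[annotated_positions]))
  assert np.isclose(direct, via_logsumexp)
  return via_logsumexp


# Example: probability mass on positions {0, 3} of a length-4 logit vector.
print(or_cross_entropy_reference(
    np.array([1.0, -2.0, 0.5, 3.0]), np.array([0, 3])))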
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]
  answer_type = features["answer_type"]
  supporting_fact = features["is_supporting_fact"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend padding tokens
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  # The "pooler" converts the encoded sequence tensor of shape
  # [batch_size, seq_length, hidden_size] to a tensor of shape
  # [batch_size, hidden_size]. This is necessary for segment-level
  # (or segment-pair-level) classification tasks where we need a fixed
  # dimensional representation of the segment.
  with tf.variable_scope("pooler"):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token. We assume that this has been pre-trained.
    first_token_tensor = tf.squeeze(
        main_output.final_hidden_states[:, 0:1, :], axis=1)
    pooled_output = tf.layers.dense(
        first_token_tensor,
        model_config.hidden_size,
        activation=tf.tanh,
        kernel_initializer=tf.truncated_normal_initializer(
            stddev=model_config.initializer_range))

  yesno_logits = yesno_model(pooled_output)
  supporting_fact_logits = supporting_fact_model(pooled_output)

  tvars = tf.trainable_variables()

  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    host_inputs = dict()

    span_prediction_loss = losses.BatchSpanCrossEntropyLoss()
    total_loss = 0

    qa_loss = span_prediction_loss(
        logits=span_logits,
        annotation_begins=features["answer_annotation_begins"],
        annotation_ends=features["answer_annotation_ends"],
        annotation_labels=features["answer_annotation_labels"],
        block_ids=block_ids,
        num_replicas=num_replicas_concat,
        eps=1e-5)
    host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)
    total_loss += qa_loss

    # example_mask = tf.cast(tf.not_equal(block_ids, 0), tf.float32)
    # yesno_loss = compute_pooled_loss(yesno_logits, answer_type, 3,
    #                                  example_mask)
    # supporting_fact_loss = compute_supporting_facts_loss(
    #     supporting_fact_logits, supporting_fact, example_mask)
    hotpot_qa_loss = hotpot_qa_losses.BatchSpanCrossEntropyLoss()
    yesno_loss, supporting_fact_loss = hotpot_qa_loss(
        yesno_logits,
        answer_type,
        supporting_fact_logits,
        supporting_fact,
        block_ids,
        eps=1e-5)

    host_inputs["train_metrics/yesno_loss"] = tf.expand_dims(yesno_loss, 0)
    total_loss += yesno_loss

    host_inputs["train_metrics/supporting_fact_loss"] = tf.expand_dims(
        supporting_fact_loss, 0)
    total_loss += supporting_fact_loss

    # Add regularization losses.
    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
        optimizer, poly_power, start_warmup_step, learning_rate_schedule,
        reduce_loss_sum=True)

    host_inputs.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "train_metrics/loss":
            tf.expand_dims(total_loss, 0),
    })
    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics")),
                 host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    begin_logits_values, begin_logits_indices = tf.math.top_k(
        span_logits[:, :, 0],
        k=nbest_logits_for_eval,
    )
    end_logits_values, end_logits_indices = tf.math.top_k(
        span_logits[:, :, 1],
        k=nbest_logits_for_eval,
    )

    predictions = {
        "block_ids": tf.identity(block_ids),
        "begin_logits_values": begin_logits_values,
        "begin_logits_indices": begin_logits_indices,
        "end_logits_values": end_logits_values,
        "end_logits_indices": end_logits_indices,
        "token_ids": tf.identity(token_ids),
        "answer_type": answer_type,
        "yesno_logits": yesno_logits,
        "supporting_fact_logits": supporting_fact_logits,
        "is_supporting_fact": supporting_fact,
    }
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
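# A small, self-contained sketch (illustrative only; not taken from the script
# above) of the padding attention mask built in model_fn: every query position
# may attend to every non-padding key, so the [batch, seq, seq] mask is just
# the per-token "is not padding" row repeated seq_length times along axis 1.
import numpy as np


def make_padding_att_mask(token_ids, padding_token_id=0):
  """Returns an int32 mask of shape [batch, seq, seq]; 1 means "may attend"."""
  not_pad = (token_ids != padding_token_id).astype(np.int32)  # [batch, seq]
  seq_length = token_ids.shape[1]
  return np.tile(not_pad[:, np.newaxis, :], (1, seq_length, 1))


token_ids = np.array([[5, 7, 0, 0]])  # the last two tokens are padding
print(make_padding_att_mask(token_ids)[0])
# [[1 1 0 0]
#  [1 1 0 0]
#  [1 1 0 0]
#  [1 1 0 0]]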
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  logging.info("*** Model: Params ***")
  for name in sorted(params.keys()):
    logging.info("  %s = %s", name, params[name])
  logging.info("*** Model: Features ***")
  for name in sorted(features.keys()):
    logging.info("  name = %s, shape = %s", name, features[name].shape)

  model = modeling.ReadItTwiceBertModel(
      config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

  span_prediction_layer = modeling.SpanPredictionHead(
      intermediate_size=model_config.intermediate_size,
      dropout_rate=model_config.hidden_dropout_prob)

  # [batch_size, main_seq_length]
  token_ids = features["token_ids"]
  main_seq_length = tf.shape(token_ids)[1]
  block_ids = features["block_ids"]
  block_pos = features["block_pos"]

  annotation_begins = features.get("entity_annotation_begins")
  annotation_ends = features.get("entity_annotation_ends")
  annotation_labels = features.get("entity_annotation_labels")

  # Do not attend padding tokens
  # [batch_size, main_seq_length, main_seq_length]
  att_mask = tf.tile(
      tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
      [1, main_seq_length, 1])
  att_mask = tf.cast(att_mask, dtype=tf.int32)

  main_output = model(
      token_ids=token_ids,
      training=(mode == tf.estimator.ModeKeys.TRAIN),
      block_ids=block_ids,
      block_pos=block_pos,
      att_mask=att_mask,
      annotation_begins=annotation_begins,
      annotation_ends=annotation_ends,
      annotation_labels=annotation_labels,
      enable_side_inputs=enable_side_inputs,
      num_replicas_concat=num_replicas_concat,
      cross_block_attention_mode=cross_block_attention_mode)

  span_logits = span_prediction_layer(
      hidden_states=main_output.final_hidden_states,
      token_ids=token_ids,
      padding_token_id=padding_token_id,
      ignore_prefix_length=features["prefix_length"],
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  is_summary_loss_enabled = (
      mode == tf.estimator.ModeKeys.TRAIN and summary_loss_weight is not None
      and summary_loss_weight > 0)
  if is_summary_loss_enabled:
    logging.info("Using summary prediction loss with weight %.3f",
                 summary_loss_weight)
    summary_token_ids = features["summary_token_ids"]
    summary_labels = tf.roll(summary_token_ids, shift=-1, axis=1)
    decoder = modeling.ReadItTwiceDecoderModel(
        config=model_config,
        num_layers_override=summary_num_layers,
        num_cross_attention_heads=summary_num_cross_attention_heads,
        enable_default_side_input=summary_enable_default_side_input,
        use_one_hot_embeddings=use_one_hot_embeddings)
    summary_token_logits = decoder(
        token_ids=summary_token_ids,
        side_input=main_output.global_summary.states,
        token2side_input_att_mask=modeling.get_cross_block_att(
            block_ids,
            block_pos,
            main_output.global_summary.block_ids,
            main_output.global_summary.block_pos,
            cross_block_attention_mode="doc"),
        training=True)
    language_model_loss_fn = losses.LanguageModelLoss(
        decoder.get_token_embedding_table(),
        hidden_size=model_config.hidden_size)
    language_model_loss = language_model_loss_fn(
        summary_token_logits,
        summary_labels,
        padding_token_id=padding_token_id).loss
  else:
    language_model_loss = None

  tvars = tf.trainable_variables()

  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = checkpoint_utils.get_assignment_map_from_checkpoint(
        tvars, init_checkpoint)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    host_inputs = dict()

    span_prediction_loss = losses.BatchSpanCrossEntropyLoss()

    qa_loss = span_prediction_loss(
        logits=span_logits,
        annotation_begins=features["answer_annotation_begins"],
        annotation_ends=features["answer_annotation_ends"],
        annotation_labels=features["answer_annotation_labels"],
        block_ids=block_ids,
        num_replicas=num_replicas_concat,
        eps=1e-5)
    host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)

    if language_model_loss is not None:
      total_loss = (
          1.0 / (1.0 + summary_loss_weight) * qa_loss +
          summary_loss_weight / (1.0 + summary_loss_weight) *
          language_model_loss)
      host_inputs["train_metrics/summary_lm_loss"] = tf.expand_dims(
          language_model_loss, 0)
    else:
      total_loss = qa_loss

    # Add regularization losses.
    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    train_op = optimization.create_optimizer(
        total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
        optimizer, poly_power, start_warmup_step, learning_rate_schedule,
        reduce_loss_sum=True)

    host_inputs.update({
        "global_step":
            tf.expand_dims(tf.train.get_or_create_global_step(), 0),
        "train_metrics/loss":
            tf.expand_dims(total_loss, 0),
    })
    host_call = (functools.partial(
        record_summary_host_fn,
        metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics")),
                 host_inputs)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn,
        host_call=host_call)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    begin_logits_values, begin_logits_indices = tf.math.top_k(
        span_logits[:, :, 0],
        k=nbest_logits_for_eval,
    )
    end_logits_values, end_logits_indices = tf.math.top_k(
        span_logits[:, :, 1],
        k=nbest_logits_for_eval,
    )

    predictions = {
        "block_ids": tf.identity(block_ids),
        "begin_logits_values": begin_logits_values,
        "begin_logits_indices": begin_logits_indices,
        "end_logits_values": end_logits_values,
        "end_logits_indices": end_logits_indices,
        "token_ids": tf.identity(token_ids),
    }
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
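# A minimal numpy sketch (values are illustrative) of the next-token labels
# produced above by tf.roll(summary_token_ids, shift=-1, axis=1): position t is
# trained to predict the token at position t + 1. The final position wraps
# around to the first token; how that position is handled depends on the
# padding masking inside LanguageModelLoss (an assumption here, since the loss
# is constructed with padding_token_id).
import numpy as np

summary_token_ids = np.array([[101, 42, 17, 0]])  # 0 is assumed to be padding
summary_labels = np.roll(summary_token_ids, shift=-1, axis=1)
print(summary_labels)  # [[ 42  17   0 101]]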
def test_cross_entropy_loss(self, seq_length, block_ids, annotation_begins,
                            annotation_ends, annotation_labels):
  np.random.seed(31415)
  unique_block_ids = set(block_ids)
  batch_size = len(block_ids)
  num_annotations = len(annotation_begins[0])
  for i in range(batch_size):
    self.assertLen(annotation_begins[i], num_annotations)
    self.assertLen(annotation_ends[i], num_annotations)
    self.assertLen(annotation_labels[i], num_annotations)

  logits = np.random.random((batch_size, seq_length, 2))
  logits = (logits - 0.5) * 100
  logits = logits.astype(np.float32)

  expected_loss_np = 0
  for block_id in unique_block_ids:
    current_indices = [
        i for i in range(batch_size) if block_ids[i] == block_id
    ]
    current_begin_logits = np.concatenate(
        [logits[i, :, 0] for i in current_indices])
    current_end_logits = np.concatenate(
        [logits[i, :, 1] for i in current_indices])
    current_begin_probs = scipy.special.softmax(current_begin_logits)
    current_end_probs = scipy.special.softmax(current_end_logits)

    current_begins, current_ends = set(), set()
    for i, sample_index in enumerate(current_indices):
      for j in range(num_annotations):
        if annotation_labels[sample_index][j] > 0:
          current_begins.add(annotation_begins[sample_index][j] +
                             i * seq_length)
          current_ends.add(annotation_ends[sample_index][j] + i * seq_length)

    if not current_begins:
      self.assertEmpty(current_ends)
      continue
    else:
      self.assertNotEmpty(current_ends)

    expected_loss_np -= (
        np.log(sum([current_begin_probs[i] for i in current_begins])) +
        np.log(sum([current_end_probs[i] for i in current_ends])))

  logits_tf = tf.compat.v1.placeholder_with_default(
      logits, shape=[None, None, 2])
  block_ids_tf = tf.compat.v1.placeholder_with_default(block_ids, shape=[None])
  annotation_begins_tf = tf.compat.v1.placeholder_with_default(
      annotation_begins, shape=[None, None])
  annotation_ends_tf = tf.compat.v1.placeholder_with_default(
      annotation_ends, shape=[None, None])
  annotation_labels_tf = tf.compat.v1.placeholder_with_default(
      annotation_labels, shape=[None, None])

  loss_layer = losses.BatchSpanCrossEntropyLoss()

  init_op = tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer())
  self.evaluate(init_op)

  actual_loss = loss_layer(logits_tf, annotation_begins_tf,
                           annotation_ends_tf, annotation_labels_tf,
                           block_ids_tf)

  actual_loss_value = self.evaluate(actual_loss)
  self.assertNear(actual_loss_value, expected_loss_np, err=1e-4)
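# A minimal usage sketch of the layer under test. Shapes and the positional
# call signature follow the tests above; the concrete values (and the choice of
# unique block ids) are illustrative only. `losses` is the module under test,
# imported at the top of the test file.
import numpy as np
import tensorflow as tf

logits = tf.constant(np.random.randn(2, 8, 2), dtype=tf.float32)  # [batch, seq, 2]
annotation_begins = tf.constant([[1, 3], [0, 5]], dtype=tf.int32)  # [batch, num_annotations]
annotation_ends = tf.constant([[2, 4], [0, 6]], dtype=tf.int32)
annotation_labels = tf.constant([[1., 1.], [1., 0.]], dtype=tf.float32)  # 0 marks an unused slot
block_ids = tf.constant([1, 2], dtype=tf.int32)  # unique ids: no cross-example grouping

loss_layer = losses.BatchSpanCrossEntropyLoss()
loss = loss_layer(logits, annotation_begins, annotation_ends,
                  annotation_labels, block_ids)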