def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator.

  Builds the pretraining graph: a table-aware BERT encoder followed by a
  masked-LM head and a next-sentence-prediction head. The total loss is the
  sum of both head losses. Closure variables (e.g. `bert_config`,
  `init_checkpoint`, `use_tpu`, `learning_rate`) come from the enclosing
  builder function.

  Args:
    features: Dict of input `Tensor`s produced by the input pipeline; must
      contain "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"
      and "next_sentence_labels" in addition to the encoder features.
    labels: Unused (supervision is carried inside `features`).
    mode: A `tf_estimator.ModeKeys` value (TRAIN, EVAL or PREDICT).
    params: Unused.

  Returns:
    A `tf_estimator.tpu.TPUEstimatorSpec` for the given mode.

  Raises:
    ValueError: If `mode` is not one of TRAIN, EVAL, PREDICT.
  """
  del labels, params  # Unused.
  tf.logging.info("*** Features ***")
  for name in sorted(features):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)
  # Pretraining supervision tensors, all taken from the feature dict.
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  next_sentence_labels = features["next_sentence_labels"]
  # Table-aware BERT encoder; all restrict_attention_* / bias options are
  # forwarded from the enclosing scope.
  model = table_bert.create_model(
      features=features,
      mode=mode,
      bert_config=bert_config,
      restrict_attention_mode=restrict_attention_mode,
      restrict_attention_bucket_size=restrict_attention_bucket_size,
      restrict_attention_header_size=restrict_attention_header_size,
      restrict_attention_row_heads_ratio=restrict_attention_row_heads_ratio,
      disabled_features=disabled_features,
      disable_position_embeddings=disable_position_embeddings,
      reset_position_index_per_cell=reset_position_index_per_cell,
      proj_value_length=proj_value_length,
      attention_bias_disabled=attention_bias_disabled,
      attention_bias_use_relative_scalar_only=attention_bias_use_relative_scalar_only,
  )
  # Masked-LM head: reuses the encoder's embedding table for the output
  # projection (weight tying).
  (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
   masked_lm_predictions) = _get_masked_lm_output(
       bert_config, model.get_sequence_output(), model.get_embedding_table(),
       masked_lm_positions, masked_lm_ids, masked_lm_weights)
  # Next-sentence head operates on the pooled ([CLS]) output.
  (next_sentence_loss, next_sentence_example_loss,
   next_sentence_log_probs) = _get_next_sentence_output(
       bert_config, model.get_pooled_output(), next_sentence_labels)
  total_loss = masked_lm_loss + next_sentence_loss
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    # Position embeddings are deliberately excluded from checkpoint
    # restoration (they are trained from scratch here).
    init_tvars = [
        tvar for tvar in tvars if "position_embeddings" not in tvar.name
    ]
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(init_tvars,
                                                    init_checkpoint)
    if use_tpu:
      # On TPU, checkpoint init must happen inside a Scaffold factory so it
      # runs on the worker rather than at graph-construction time.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  output_spec = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(total_loss, learning_rate,
                                             num_train_steps,
                                             num_warmup_steps, use_tpu)
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_log_probs, next_sentence_labels):
      """Computes the loss and accuracy of the model."""
      # Flatten prediction-position tensors so accuracy/mean metrics see one
      # row per masked position.
      masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                       [-1, masked_lm_log_probs.shape[-1]])
      masked_lm_predictions = tf.argmax(
          masked_lm_log_probs, axis=-1, output_type=tf.int32)
      masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
      masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
      # Weights mask out padding positions from both metrics.
      masked_lm_accuracy = tf.metrics.accuracy(
          labels=masked_lm_ids,
          predictions=masked_lm_predictions,
          weights=masked_lm_weights)
      masked_lm_mean_loss = tf.metrics.mean(
          values=masked_lm_example_loss, weights=masked_lm_weights)
      next_sentence_log_probs = tf.reshape(
          next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
      next_sentence_predictions = tf.argmax(
          next_sentence_log_probs, axis=-1, output_type=tf.int32)
      next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
      next_sentence_accuracy = tf.metrics.accuracy(
          labels=next_sentence_labels,
          predictions=next_sentence_predictions)
      next_sentence_mean_loss = tf.metrics.mean(
          values=next_sentence_example_loss)
      return {
          "masked_lm_accuracy": masked_lm_accuracy,
          "masked_lm_loss": masked_lm_mean_loss,
          "next_sentence_accuracy": next_sentence_accuracy,
          "next_sentence_loss": next_sentence_mean_loss,
      }

    # TPUEstimator eval_metrics: (fn, ordered-arg-list); order must match
    # metric_fn's signature.
    eval_metrics = (metric_fn, [
        masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
        masked_lm_weights, next_sentence_example_loss,
        next_sentence_log_probs, next_sentence_labels
    ])
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  elif mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        "masked_lm_predictions": masked_lm_predictions,
    }
    # NOTE(review): loss is also passed in PREDICT mode; presumably ignored
    # by TPUEstimator — confirm if PREDICT is actually exercised.
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        predictions=predictions,
        scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Unsupported mode: %s" % mode)
  return output_spec
def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator.

  Builds a dual-encoder retrieval model: one table-BERT encoder for tables
  and one for questions (the question is encoded against an all-zero "empty
  table"). Similarity logits are the dot products of the (optionally
  down-projected) pooled representations; the gold table for each question
  is the in-batch diagonal, trained with softmax cross-entropy. Closure
  variables (`config`, `table_bert`, helper functions, `_INF`) come from the
  enclosing builder function.

  Args:
    features: Dict of input `Tensor`s; contains the table encoder features,
      the "question_*" features, and "table_id_hash"/"question_hash" used
      for repeated-table/question masking.
    labels: Unused.
    mode: A `tf_estimator.ModeKeys` value.
    params: Unused.

  Returns:
    A `tf_estimator.tpu.TPUEstimatorSpec` for the given mode.

  Raises:
    ValueError: If out-of-core negatives, mined negatives and repeated
      question masking are all enabled at once (unsupported combination).
  """
  del labels, params  # Unused.
  # TODO(thomasmueller) Add support for this.
  if (config.use_out_of_core_negatives and config.use_mined_negatives and
      config.mask_repeated_questions):
    raise ValueError("Unsupported combination of options.")
  tf.logging.info("*** Features ***")
  for name in sorted(features):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)
  if config.ignore_table_content:
    # Zero out the input mask over table cells (row_ids > 0) so the encoder
    # only attends to the question/header portion.
    features["input_mask"] = tf.where(features["row_ids"] > 0,
                                      tf.zeros_like(features["input_mask"]),
                                      tf.ones_like(features["input_mask"]))
  features_table_names = [
      "input_ids", "input_mask", "segment_ids", "column_ids", "row_ids",
      "prev_label_ids", "column_ranks", "inv_column_ranks",
      "numeric_relations"
  ]
  features_table = {}
  for name in features_table_names:
    if config.use_mined_negatives:
      # split each feature in half, and concat vertically.
      feature_positive_table, feature_negative_table = tf.split(
          features[name], num_or_size_splits=2, axis=1)
      features_table[name] = tf.concat(
          [feature_positive_table, feature_negative_table], axis=0)
    else:
      features_table[name] = features[name]
  tf.logging.info("*** Features table ***")
  for name in sorted(features_table):
    # Fix: log the shape of the (possibly split-and-concatenated) table
    # feature, not the original input feature.
    tf.logging.info(" name = %s, shape = %s", name,
                    features_table[name].shape)
  table_model = table_bert.create_model(
      features=features_table,
      disabled_features=config.disabled_features,
      mode=mode,
      bert_config=config.bert_config,
  )
  # Arrange features for query, such that it is assigned with an empty table.
  empty_table_features = tf.zeros_like(
      features["question_input_mask"])[:, :config.max_query_length]
  features_query = {
      "input_ids":
          features["question_input_ids"][:, :config.max_query_length],
      "input_mask":
          features["question_input_mask"][:, :config.max_query_length],
      "segment_ids": empty_table_features,
      "column_ids": empty_table_features,
      "row_ids": empty_table_features,
      "prev_label_ids": empty_table_features,
      "column_ranks": empty_table_features,
      "inv_column_ranks": empty_table_features,
      "numeric_relations": empty_table_features,
  }
  query_model = table_bert.create_model(
      features=features_query,
      disabled_features=config.disabled_features,
      mode=mode,
      bert_config=config.bert_config,
  )
  table_hidden_representation = table_model.get_pooled_output()
  query_hidden_representation = query_model.get_pooled_output()
  if config.down_projection_dim > 0:
    # Project both towers into a shared lower-dimensional space.
    table_projection = _get_projection_matrix(
        "table_projection", num_columns=table_hidden_representation.shape[1])
    query_projection = _get_projection_matrix(
        "text_projection", num_columns=query_hidden_representation.shape[1])
    # <float32>[batch_size * num_tables, down_projection_dim]
    table_rep = _get_type_representation(table_hidden_representation,
                                         table_projection)
    # <float32>[batch_size, down_projection_dim]
    query_rep = _get_type_representation(query_hidden_representation,
                                         query_projection)
  else:
    table_rep = table_hidden_representation
    query_rep = query_hidden_representation
  batch_size = tf.shape(query_rep)[0]
  # Identity matrix, as gold logits are on the diagonal.
  labels_single_table = tf.eye(batch_size)
  if config.use_mined_negatives:
    # <int64>[batch_size, batch_size * num_tables]
    labels = tf.concat(
        [labels_single_table, tf.zeros_like(labels_single_table)], axis=1)
  else:
    labels = labels_single_table
  # <int64>[batch_size * num_tables]
  table_id_hash = tf.reshape(tf.transpose(features["table_id_hash"]),
                             shape=[-1])
  # <int64>[batch_size, 1]
  table_id_hash_transposed = features["table_id_hash"][..., :1]
  # <int64>[batch_size, 1]
  question_hash = features["question_hash"]
  # <int64>[1, batch_size]
  question_hash_transposed = tf.transpose(question_hash)
  if config.use_tpu and config.use_out_of_core_negatives:
    # Gather negatives from the other TPU cores (enlarges the effective
    # batch of candidate tables).
    data = get_updates_for_use_tpu_with_out_of_core_negatives(
        ModelBuilderData(table_rep, table_id_hash, question_hash, labels))
    table_rep = data.table_rep
    table_id_hash = data.table_id_hash
    question_hash = data.question_hash
    labels = data.labels
  # <float32>[batch_size, batch_size|global_batch_size * num_tables]
  logits = tf.matmul(query_rep, table_rep, transpose_b=True)
  if config.mask_repeated_tables:
    # Matrix of non-trivial repeated tables
    # <bool>[batch_size, batch_size|global_batch_size * num_tables]
    repeated_tables = tf.math.equal(
        tf.expand_dims(table_id_hash, axis=0),
        table_id_hash_transposed) & tf.math.equal(labels, 0)
    # Push repeated non-gold tables to -inf so they cannot act as negatives.
    logits = tf.where(repeated_tables, tf.zeros_like(logits) - _INF, logits)
  if config.mask_repeated_questions:
    logits = _apply_reapated_text_masking(
        config,
        question_hash,
        question_hash_transposed,
        labels,
        logits,
    )
  total_loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                               logits=logits)
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  init_checkpoint = config.init_checkpoint
  if init_checkpoint:
    # May produce several assignment maps (e.g. one per encoder tower).
    (assignment_maps,
     initialized_variable_names) = _get_assignment_map_from_checkpoint(
         tvars, init_checkpoint, config.init_from_single_encoder)
    if config.use_tpu:
      # Checkpoint init must run inside the Scaffold factory on TPU.
      def tpu_scaffold():
        for assignment_map in assignment_maps:
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      for assignment_map in assignment_maps:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  output_spec = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss,
        config.learning_rate,
        config.num_train_steps,
        config.num_warmup_steps,
        config.use_tpu,
        grad_clipping=config.grad_clipping)
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf_estimator.ModeKeys.EVAL:
    eval_metrics = (_calculate_eval_metrics_fn,
                    [total_loss, logits, labels])
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    # PREDICT: emit representations; drop mined-negative table rows so only
    # the gold tables are exported.
    if config.use_mined_negatives:
      table_rep_gold, _ = tf.split(table_rep, num_or_size_splits=2, axis=0)
    else:
      table_rep_gold = table_rep
    predictions = {
        "query_rep": query_rep,
        "table_rep": table_rep_gold,
    }
    # Only available when predicting on GPU.
    if "table_id" in features:
      predictions["table_id"] = features["table_id"]
    if "question_id" in features:
      predictions["query_id"] = features["question_id"]
    output_spec = tf_estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  return output_spec
def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator.

  Builds the table classification/QA graph: a table-BERT encoder plus
  cell-selection, aggregation and classification heads (via
  `_get_classification_outputs`), with support for span outputs and
  initialization from one or more checkpoints. Closure variables
  (`config`, `custom_prediction_keys`, `table_bert`, `utils`, `modeling`,
  `optimization`) come from the enclosing builder function.

  Args:
    features: Dict of input `Tensor`s; must contain "label_ids",
      "input_mask", "row_ids", "column_ids", "question_id_ints" and,
      depending on config, aggregation/classification supervision.
    labels: Unused (supervision is carried inside `features`).
    mode: A `tf.estimator.ModeKeys` value.
    params: Dict; reads optional "gradient_accumulation_steps" and
      "fail_if_missing_variables_in_checkpoint".

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` for the given mode.
  """
  del labels  # Unused.
  tf.logging.info("*** Features ***")
  for name in sorted(features):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)
  label_ids = features["label_ids"]
  input_mask = features["input_mask"]
  row_ids = features["row_ids"]
  column_ids = features["column_ids"]
  # Table cells only, without question tokens and table headers.
  table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids),
                        tf.zeros_like(row_ids))
  # Aggregation / classification heads are only active when the config
  # declares a non-zero label count for them.
  do_model_aggregation = config.num_aggregation_labels > 0
  aggregation_function_id = (
      tf.squeeze(features["aggregation_function_id"], axis=[1])
      if do_model_aggregation else None)
  do_model_classification = config.num_classification_labels > 0
  classification_class_index = (
      tf.squeeze(features["classification_class_index"], axis=[1])
      if do_model_classification else None)
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = table_bert.create_model(
      features=features,
      mode=mode,
      bert_config=config.bert_config,
      disabled_features=config.disabled_features,
      disable_position_embeddings=config.disable_position_embeddings,
      reset_position_index_per_cell=config.reset_position_index_per_cell,
      proj_value_length=config.proj_value_length,
  )
  # Weak supervision from answers (None triple when disabled).
  answer, numeric_values, numeric_values_scale = (
      utils.extract_answer_from_features(
          features=features,
          use_answer_as_supervision=config.use_answer_as_supervision))
  outputs = _get_classification_outputs(
      config=config,
      output_layer=model.get_sequence_output(),
      output_layer_aggregation=model.get_pooled_output(),
      label_ids=label_ids,
      input_mask=input_mask,
      table_mask=table_mask,
      aggregation_function_id=aggregation_function_id,
      answer=answer,
      numeric_values=numeric_values,
      numeric_values_scale=numeric_values_scale,
      is_training=is_training,
      row_ids=row_ids,
      column_ids=column_ids,
      classification_class_index=classification_class_index)
  total_loss = outputs.total_loss
  tvars = tf.trainable_variables()
  if config.reset_output_cls:
    # Exclude the CLS output head from checkpoint restoration so it is
    # re-initialized (e.g. when fine-tuning with a different label set).
    tvars = [
        tvar for tvar in tvars
        if ("output_weights_cls" not in tvar.name and
            "output_bias_cls" not in tvar.name)
    ]
  initialized_variable_names = set()
  scaffold_fn = None
  init_from_checkpoints = []

  def add_init_checkpoint(init_checkpoint, scope=None):
    # Collects (checkpoint, assignment_map) pairs; no-op for empty paths.
    if not init_checkpoint:
      return
    (assignment_map,
     initialized_variables) = modeling.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint, scope=scope)
    initialized_variable_names.update(initialized_variables.keys())
    init_from_checkpoints.append((init_checkpoint, assignment_map))

  add_init_checkpoint(config.init_checkpoint)
  if init_from_checkpoints:
    if config.use_tpu:
      # On TPU, checkpoint init must happen inside the Scaffold factory.
      def tpu_scaffold():
        for init_checkpoint, assignment_map in init_from_checkpoints:
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      for init_checkpoint, assignment_map in init_from_checkpoints:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  # Optionally abort if a trainable variable was not found in any
  # checkpoint (layer-norm variables are exempt).
  fail_if_missing = init_from_checkpoints and params.get(
      "fail_if_missing_variables_in_checkpoint", False)
  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    elif fail_if_missing:
      if "layer_norm" not in var.name and "LayerNorm" not in var.name:
        tf.logging.fatal("Variable not found in checkpoint: %s", var.name)
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss,
        config.learning_rate,
        config.num_train_steps,
        config.num_warmup_steps,
        config.use_tpu,
        gradient_accumulation_steps=params.get("gradient_accumulation_steps",
                                               1),
        grad_clipping=config.grad_clipping)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.EVAL:
    # Argument order must match _calculate_eval_metrics_fn's signature.
    eval_metrics = (_calculate_eval_metrics_fn, [
        total_loss,
        label_ids,
        outputs.logits,
        input_mask,
        aggregation_function_id,
        outputs.logits_aggregation,
        classification_class_index,
        outputs.logits_cls,
    ])
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    # PREDICT mode.
    predictions = {
        "probabilities": outputs.probs,
        "input_ids": features["input_ids"],
        "column_ids": features["column_ids"],
        "row_ids": features["row_ids"],
        "segment_ids": features["segment_ids"],
        "question_id_ints": features["question_id_ints"],
    }
    if "question_id" in features:
      # Only available when predicting on GPU.
      predictions["question_id"] = features["question_id"]
      # Prefer the string id over its int encoding when both exist.
      del predictions["question_id_ints"]
    if do_model_aggregation:
      predictions.update({
          "gold_aggr":
              features["aggregation_function_id"],
          "pred_aggr":
              tf.argmax(
                  outputs.logits_aggregation,
                  axis=-1,
                  output_type=tf.int32,
              )
      })
    if do_model_classification:
      predictions.update({
          "gold_cls":
              features["classification_class_index"],
          "pred_cls":
              tf.argmax(
                  outputs.logits_cls,
                  axis=-1,
                  output_type=tf.int32,
              )
      })
      if config.num_classification_labels == 2:
        # For binary classification export a single margin score.
        predictions.update({
            "logits_cls":
                outputs.logits_cls[:, 1] - outputs.logits_cls[:, 0]
        })
      else:
        predictions.update({"logits_cls": outputs.logits_cls})
    if outputs.span_indexes is not None and outputs.span_logits is not None:
      predictions.update({"span_indexes": outputs.span_indexes})
      predictions.update({"span_logits": outputs.span_logits})
    if custom_prediction_keys:
      # Restrict the exported predictions to the requested keys.
      predictions = {
          key: predictions[key] for key in custom_prediction_keys
      }
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  return output_spec
def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator.

  Older variant of the table classification graph: a table-BERT encoder
  plus cell-selection, aggregation and classification heads, with
  `_get_classification_outputs` returning a plain tuple (instead of an
  outputs object) and single-checkpoint initialization. Closure variables
  (`config`, `table_bert`, `modeling`, `optimization`) come from the
  enclosing builder function.

  Args:
    features: Dict of input `Tensor`s; must contain "label_ids",
      "input_mask", "row_ids", "column_ids", "question_id_ints" and,
      depending on config, answer/aggregation/classification supervision.
    labels: Unused (supervision is carried inside `features`).
    mode: A `tf.estimator.ModeKeys` value.
    params: Dict; reads optional "gradient_accumulation_steps".

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` for the given mode.
  """
  del labels  # Unused.
  tf.logging.info("*** Features ***")
  for name in sorted(features):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)
  label_ids = features["label_ids"]
  input_mask = features["input_mask"]
  row_ids = features["row_ids"]
  column_ids = features["column_ids"]
  # Table cells only, without question tokens and table headers.
  table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids),
                        tf.zeros_like(row_ids))
  # Optional heads, enabled by non-zero label counts in the config.
  do_model_aggregation = config.num_aggregation_labels > 0
  aggregation_function_id = (
      tf.squeeze(features["aggregation_function_id"], axis=[1])
      if do_model_aggregation else None)
  do_model_classification = config.num_classification_labels > 0
  classification_class_index = (
      tf.squeeze(features["classification_class_index"], axis=[1])
      if do_model_classification else None)
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = table_bert.create_model(
      features=features,
      mode=mode,
      bert_config=config.bert_config,
      disabled_features=config.disabled_features,
      disable_position_embeddings=config.disable_position_embeddings)
  if config.use_answer_as_supervision:
    # Weak supervision signals for the scalar-answer loss.
    answer = tf.squeeze(features["answer"], axis=[1])
    numeric_values = features["numeric_values"]
    numeric_values_scale = features["numeric_values_scale"]
  else:
    answer = None
    numeric_values = None
    numeric_values_scale = None
  (total_loss, logits, logits_aggregation, probabilities,
   logits_cls) = _get_classification_outputs(
       config=config,
       output_layer=model.get_sequence_output(),
       output_layer_aggregation=model.get_pooled_output(),
       label_ids=label_ids,
       input_mask=input_mask,
       table_mask=table_mask,
       aggregation_function_id=aggregation_function_id,
       answer=answer,
       numeric_values=numeric_values,
       numeric_values_scale=numeric_values_scale,
       is_training=is_training,
       row_ids=row_ids,
       column_ids=column_ids,
       classification_class_index=classification_class_index)
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  init_checkpoint = config.init_checkpoint
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if config.use_tpu:
      # On TPU, checkpoint init must happen inside the Scaffold factory.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(
        total_loss,
        config.learning_rate,
        config.num_train_steps,
        config.num_warmup_steps,
        config.use_tpu,
        gradient_accumulation_steps=params.get("gradient_accumulation_steps",
                                               1),
        grad_clipping=config.grad_clipping)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.EVAL:
    # Argument order must match _calculate_eval_metrics_fn's signature.
    eval_metrics = (_calculate_eval_metrics_fn, [
        total_loss, label_ids, logits, input_mask, aggregation_function_id,
        logits_aggregation, classification_class_index, logits_cls
    ])
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    # PREDICT mode.
    predictions = {
        "embeddings": model.get_sequence_output(),
        "probabilities": probabilities,
        "column_ids": features["column_ids"],
        "row_ids": features["row_ids"],
        "segment_ids": features["segment_ids"],
        "question_id_ints": features["question_id_ints"],
    }
    # TODO Remove once the data has been updated.
    if "question_id" in features:
      # Only available when predicting on GPU.
      predictions["question_id"] = features["question_id"]
    if do_model_aggregation:
      predictions.update({
          "gold_aggr":
              features["aggregation_function_id"],
          "pred_aggr":
              tf.argmax(logits_aggregation, axis=-1, output_type=tf.int32)
      })
    if do_model_classification:
      predictions.update({
          "gold_cls":
              features["classification_class_index"],
          "pred_cls":
              tf.argmax(logits_cls, axis=-1, output_type=tf.int32)
      })
      if config.num_classification_labels == 2:
        # For binary classification export a single margin score.
        predictions.update(
            {"logits_cls": logits_cls[:, 1] - logits_cls[:, 0]})
      else:
        predictions.update({"logits_cls": logits_cls})
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  return output_spec