def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator.

  Builds the DrKIT or DrFact graph selected by `FLAGS.model_type`, optionally
  warm-starts trainable variables from `init_checkpoint`, and wraps the result
  in a `tf.estimator.tpu.TPUEstimatorSpec`.

  Args:
    features: Dict of input tensors keyed by feature name.
    labels: Unused.
    mode: A `tf.estimator.ModeKeys` value; only TRAIN and PREDICT are handled.
    params: Unused.

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` for the requested mode.

  Raises:
    ValueError: If `FLAGS.model_type` is neither "drkit" nor "drfact", or if
      `mode` is neither TRAIN nor PREDICT.
  """
  del labels, params  # Not used.

  # Fail fast: an unknown model_type would otherwise leave `total_loss` and
  # `predictions` unbound and only surface later as an UnboundLocalError.
  if FLAGS.model_type not in ("drkit", "drfact"):
    raise ValueError("Unsupported model_type: %s" % FLAGS.model_type)

  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  def _load_ragged(matrix_name, checkpoint, scope_name):
    """Returns (indices, values) RaggedTensors for a stored ragged matrix.

    The raw (data, indices, row_splits) arrays are loaded on the CPU; the two
    RaggedTensors share `row_splits` and are built under `scope_name` without
    row-split validation.
    """
    with tf.device("/cpu:0"):
      data, indices, row_splits = search_utils.load_ragged_matrix(
          matrix_name, checkpoint)
    with tf.name_scope(scope_name):
      ragged_ind = tf.RaggedTensor.from_row_splits(
          values=indices, row_splits=row_splits, validate=False)
      ragged_val = tf.RaggedTensor.from_row_splits(
          values=data, row_splits=row_splits, validate=False)
    return ragged_ind, ragged_val

  entity_ids = search_utils.load_database(
      "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
      entity_id_checkpoint, dtype=tf.int32)
  entity_mask = search_utils.load_database(
      "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
      entity_mask_checkpoint)

  if FLAGS.model_type == "drkit":
    # Sparse entity -> mention matrix plus the mention -> entity (coref) map.
    e2m_ragged_ind, e2m_ragged_val = _load_ragged(
        "ent2ment", e2m_checkpoint, "RaggedConstruction_e2m")
    tf_m2e_map = search_utils.load_database(
        "coref", [mips_config.num_mentions], m2e_checkpoint, dtype=tf.int32)
    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        mips_config=mips_config,
        is_training=is_training,
        features=features,
        ent2ment_ind=e2m_ragged_ind,
        ent2ment_val=e2m_ragged_val,
        ment2ent_map=tf_m2e_map,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )
  else:  # FLAGS.model_type == "drfact" (validated above).
    # Sparse entity -> fact, fact -> entity and fact -> fact matrices.
    e2f_ragged_ind, e2f_ragged_val = _load_ragged(
        "ent2fact", e2f_checkpoint, "RaggedConstruction_e2f")
    f2e_ragged_ind, f2e_ragged_val = _load_ragged(
        "fact2ent", f2e_checkpoint, "RaggedConstruction_f2e")
    f2f_ragged_ind, f2f_ragged_val = _load_ragged(
        "fact2fact", f2f_checkpoint, "RaggedConstruction_f2f")
    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        fact_mips_config=fact_mips_config,
        is_training=is_training,
        features=features,
        ent2fact_ind=e2f_ragged_ind,
        ent2fact_val=e2f_ragged_val,
        fact2ent_ind=f2e_ragged_ind,
        fact2ent_val=f2e_ragged_val,
        fact2fact_ind=f2f_ragged_ind,
        fact2fact_val=f2f_ragged_val,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )

  tvars = tf.trainable_variables()

  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = get_assignment_map_from_checkpoint(
        tvars, init_checkpoint, load_only_bert=qa_config.load_only_bert)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
    devices = tf.config.experimental.list_logical_devices("GPU")
    memory_footprints = []
    for device in devices:
      # The memory-stats ops report on the device they are placed on. Pin them
      # to this GPU; otherwise every line would show the same device's stats.
      with tf.device(device.name):
        max_bytes_in_use = contrib_memory_stats.MaxBytesInUse()
        bytes_limit = contrib_memory_stats.BytesLimit()
      memory_footprint = tf.print(
          device.name,
          max_bytes_in_use / one_mb, " / ",
          bytes_limit / one_mb)
      memory_footprints.append(memory_footprint)

    # Make the optimizer step depend on the prints so they run every step.
    with tf.control_dependencies(memory_footprints):
      train_op = create_optimizer(total_loss, learning_rate, num_train_steps,
                                  num_warmup_steps, use_tpu, False)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                     (mode))

  return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """TPUEstimator `model_fn` for inference with the DrKIT graph.

  Loads the entity->mention ragged matrix, the mention->entity ("coref") map
  and the entity id/mask databases, builds the model via `create_model_fn`,
  and optionally warm-starts trainable variables from `init_checkpoint`.

  Args:
    features: Dict of input tensors keyed by feature name.
    labels: Unused.
    mode: A `tf.estimator.ModeKeys` value; only PREDICT is accepted.
    params: Unused.

  Returns:
    A `contrib_tpu.TPUEstimatorSpec` carrying the model predictions.

  Raises:
    ValueError: If `mode` is not PREDICT. The check runs after the graph has
      been built.
  """
  tf.logging.info("*** Features ***")
  for feature_name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s", feature_name,
                    features[feature_name].shape)

  training = mode == tf.estimator.ModeKeys.TRAIN

  # Raw arrays for the entity -> mention matrix are read on the CPU.
  with tf.device("/cpu:0"):
    e2m_values, e2m_columns, e2m_splits = search_utils.load_ragged_matrix(
        "ent2ment", e2m_checkpoint)
  with tf.name_scope("RaggedConstruction"):
    ent2ment_ind = tf.RaggedTensor.from_row_splits(
        values=e2m_columns, row_splits=e2m_splits, validate=False)
    ent2ment_val = tf.RaggedTensor.from_row_splits(
        values=e2m_values, row_splits=e2m_splits, validate=False)

  ment2ent_map = search_utils.load_database(
      "coref", [mips_config.num_mentions], m2e_checkpoint, dtype=tf.int32)
  entity_id_table = search_utils.load_database(
      "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
      entity_id_checkpoint, dtype=tf.int32)
  entity_mask_table = search_utils.load_database(
      "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
      entity_mask_checkpoint)

  _unused_loss, model_predictions = create_model_fn(
      bert_config=bert_config,
      qa_config=qa_config,
      mips_config=mips_config,
      is_training=training,
      features=features,
      ent2ment_ind=ent2ment_ind,
      ent2ment_val=ent2ment_val,
      ment2ent_map=ment2ent_map,
      entity_ids=entity_id_table,
      entity_mask=entity_mask_table,
      use_one_hot_embeddings=use_one_hot_embeddings,
      summary_obj=summary_obj)

  trainable = tf.trainable_variables()
  scaffold = None
  if init_checkpoint:
    ckpt_map, _unused_names = get_assignment_map_from_checkpoint(
        trainable, init_checkpoint, load_only_bert=qa_config.load_only_bert)
    if not use_tpu:
      tf.train.init_from_checkpoint(init_checkpoint, ckpt_map)
    else:
      # On TPU the warm start must happen inside the scaffold function.
      def _tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, ckpt_map)
        return tf.train.Scaffold()

      scaffold = _tpu_scaffold

  if mode != tf.estimator.ModeKeys.PREDICT:
    raise ValueError("Only PREDICT mode is supported: %s" % (mode))
  return contrib_tpu.TPUEstimatorSpec(
      mode=mode, predictions=model_predictions, scaffold_fn=scaffold)