def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator."""
  del labels, params  # Not used.
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s", name, features[name].shape)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  # Load the pre-built entity token ids and masks from their checkpoints.
  entity_ids = search_utils.load_database(
      "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
      entity_id_checkpoint,
      dtype=tf.int32)
  entity_mask = search_utils.load_database(
      "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
      entity_mask_checkpoint)

  if FLAGS.model_type == "drkit":
    # Initialize sparse tensor of ent2ment.
    with tf.device("/cpu:0"):
      tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
          search_utils.load_ragged_matrix("ent2ment", e2m_checkpoint))
      with tf.name_scope("RaggedConstruction_e2m"):
        e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_e2m_indices, row_splits=tf_e2m_rowsplits,
            validate=False)
        e2m_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_e2m_data, row_splits=tf_e2m_rowsplits, validate=False)

    tf_m2e_map = search_utils.load_database(
        "coref", [mips_config.num_mentions], m2e_checkpoint, dtype=tf.int32)

    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        mips_config=mips_config,
        is_training=is_training,
        features=features,
        ent2ment_ind=e2m_ragged_ind,
        ent2ment_val=e2m_ragged_val,
        ment2ent_map=tf_m2e_map,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )
  elif FLAGS.model_type == "drfact":
    # Initialize sparse tensor of ent2fact.
    with tf.device("/cpu:0"):  # Note: cpu or gpu?
      tf_e2f_data, tf_e2f_indices, tf_e2f_rowsplits = (
          search_utils.load_ragged_matrix("ent2fact", e2f_checkpoint))
      with tf.name_scope("RaggedConstruction_e2f"):
        e2f_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_e2f_indices, row_splits=tf_e2f_rowsplits,
            validate=False)
        e2f_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_e2f_data, row_splits=tf_e2f_rowsplits, validate=False)
    # Initialize sparse tensor of fact2ent.
    with tf.device("/cpu:0"):
      tf_f2e_data, tf_f2e_indices, tf_f2e_rowsplits = (
          search_utils.load_ragged_matrix("fact2ent", f2e_checkpoint))
      with tf.name_scope("RaggedConstruction_f2e"):
        f2e_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_f2e_indices, row_splits=tf_f2e_rowsplits,
            validate=False)
        f2e_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_f2e_data, row_splits=tf_f2e_rowsplits, validate=False)
    # Initialize sparse tensor of fact2fact.
    with tf.device("/cpu:0"):
      tf_f2f_data, tf_f2f_indices, tf_f2f_rowsplits = (
          search_utils.load_ragged_matrix("fact2fact", f2f_checkpoint))
      with tf.name_scope("RaggedConstruction_f2f"):
        f2f_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_f2f_indices, row_splits=tf_f2f_rowsplits,
            validate=False)
        f2f_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_f2f_data, row_splits=tf_f2f_rowsplits, validate=False)

    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        fact_mips_config=fact_mips_config,
        is_training=is_training,
        features=features,
        ent2fact_ind=e2f_ragged_ind,
        ent2fact_val=e2f_ragged_val,
        fact2ent_ind=f2e_ragged_ind,
        fact2ent_val=f2e_ragged_val,
        fact2fact_ind=f2f_ragged_ind,
        fact2fact_val=f2f_ragged_val,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map,
     initialized_variable_names) = get_assignment_map_from_checkpoint(
         tvars, init_checkpoint, load_only_bert=qa_config.load_only_bert)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Log per-GPU memory usage (in MB) before the train op runs.
    one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
    devices = tf.config.experimental.list_logical_devices("GPU")
    memory_footprints = []
    for device in devices:
      memory_footprint = tf.print(
          device.name,
          contrib_memory_stats.MaxBytesInUse() / one_mb, " / ",
          contrib_memory_stats.BytesLimit() / one_mb)
      memory_footprints.append(memory_footprint)

    with tf.control_dependencies(memory_footprints):
      train_op = create_optimizer(total_loss, learning_rate, num_train_steps,
                                  num_warmup_steps, use_tpu, False)

    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)

  return output_spec
def train(self, sess):
  """Main training function/loop.

  Args:
    sess: a tf session object
  """
  # For debugging/pushing limits of model: build per-GPU memory probes.
  gpu_mb = tf.constant(1024 * 1024, dtype=tf.int64)
  gpus = tf.config.experimental.list_logical_devices("GPU")
  memory_footprints = []
  for gpu in gpus:
    with tf.device(gpu.name):
      memory_footprint = tf.Print(
          tf.constant(0), [
              contrib_memory_stats.BytesLimit() / gpu_mb,
              contrib_memory_stats.MaxBytesInUse() / gpu_mb
          ],
          message=gpu.name)
      memory_footprints.append(memory_footprint)

  epochs = FLAGS.num_epochs
  prints = FLAGS.log_frequency

  training_start_time = time.time()
  epochs_start_time = time.time()
  num_batches = max(int(len(self.train_examples) / self.batch_size), 1)
  tf.logging.info("Num batches per epoch: {}".format(num_batches))
  # Additional logging.
  losses = np.zeros(epochs * num_batches)
  accuracies = np.zeros(epochs * num_batches)
  for epoch in range(epochs):
    random.shuffle(self.train_examples)
    for batch in range(num_batches):
      batch_no = epoch * num_batches + batch
      should_sample = (batch_no % prints == 0)
      train_ops_to_run = {
          "train_step": self.train_step,
          "loss": self.model.loss,
          "accuracy": self.model.accuracy,
          "accuracy_per_example": self.model.accuracy_per_ex,
          "output_relations": self.model.log_decoded_relations,
      }
      if should_sample:
        train_ops_to_run["props"] = self.model.property_loss
        train_ops_to_run["regularization"] = self.model.regularization
        for i, memory_footprint in enumerate(memory_footprints):
          train_ops_to_run["memory_footprint_{}".format(i)] = memory_footprint
      # Stride by batch_size so consecutive minibatches do not overlap
      # (the original sliced [batch:batch + batch_size], which reused
      # examples and never reached the tail of the epoch).
      batch_examples = self.train_examples[batch * self.batch_size:
                                           (batch + 1) * self.batch_size]
      feed_dict = self._compute_feed_dict(batch_examples)
      train_output = sess.run(train_ops_to_run, feed_dict)
      losses[batch_no] = train_output["loss"]
      accuracies[batch_no] = train_output["accuracy"]
      if should_sample:
        # Timing info.
        epochs_end_time = time.time()
        epochs_time_str = str(
            datetime.timedelta(seconds=epochs_end_time - epochs_start_time))
        epochs_start_time = epochs_end_time
        precision, recall = self._evaluate_sample(
            sess, train_output, feed_dict, batch_examples, full_log=True)
        if precision and recall:
          pr_string = "\tPrecision: {:.3f}\tRecall {:.3f}".format(
              np.mean(precision), np.mean(recall))
        else:
          pr_string = ""
        tf.logging.info(
            ("[{}] Epoch: {}.{}\tLoss: {:.3f}|{:.3f}|{:.3f}\t" +
             "Accuracy: {:.3f}{}\n").format(epochs_time_str, epoch, batch,
                                            train_output["loss"],
                                            train_output["props"],
                                            train_output["regularization"],
                                            train_output["accuracy"],
                                            pr_string))
        # Do a dev run; it doesn't take that long.
        self.evaluate(sess, full=False)
  training_end_time = time.time()
  tf.logging.info("Training took: %s" % str(
      datetime.timedelta(seconds=training_end_time - training_start_time)))
  if self.ckpt_dir is not None:
    save_path = self.saver.save(sess,
                                os.path.join(self.ckpt_dir, "model.ckpt"))
    tf.logging.info("Saved model at {}".format(save_path))
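# A minimal sketch of how `train` might be driven, assuming the enclosing
# trainer object (called `trainer` here, a hypothetical name) has already
# built its graph, `self.train_step`, and `self.saver` in its constructor:
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     trainer.train(sess)
#     trainer.evaluate(sess, full=True)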