Exemplo n.º 1
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds either the DrKIT or the DrFact graph (selected by
        `FLAGS.model_type`), optionally initializes variables from
        `init_checkpoint`, and returns a `TPUEstimatorSpec` for TRAIN or
        PREDICT mode.

        Args:
          features: dict mapping feature names to input Tensors.
          labels: unused; required by the Estimator API.
          mode: a `tf.estimator.ModeKeys` value.
          params: unused; required by the TPUEstimator API.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` for the requested mode.

        Raises:
          ValueError: if `FLAGS.model_type` is not "drkit" or "drfact", or
            if `mode` is not TRAIN or PREDICT.
        """
        del labels, params  # Not used.
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Pre-computed entity token ids and masks are restored from
        # dedicated checkpoints rather than built in-graph.
        entity_ids = search_utils.load_database(
            "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
            entity_id_checkpoint,
            dtype=tf.int32)
        entity_mask = search_utils.load_database(
            "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
            entity_mask_checkpoint)

        if FLAGS.model_type == "drkit":
            # Initialize sparse tensor of ent2ment. Ragged construction is
            # pinned to CPU to avoid holding the index tensors on device.
            with tf.device("/cpu:0"):
                tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
                    search_utils.load_ragged_matrix("ent2ment",
                                                    e2m_checkpoint))
                with tf.name_scope("RaggedConstruction_e2m"):
                    # validate=False skips row-split checks; the loaded
                    # splits are assumed well-formed.
                    e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_e2m_indices,
                        row_splits=tf_e2m_rowsplits,
                        validate=False)
                    e2m_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_e2m_data,
                        row_splits=tf_e2m_rowsplits,
                        validate=False)

            tf_m2e_map = search_utils.load_database("coref",
                                                    [mips_config.num_mentions],
                                                    m2e_checkpoint,
                                                    dtype=tf.int32)

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                mips_config=mips_config,
                is_training=is_training,
                features=features,
                ent2ment_ind=e2m_ragged_ind,
                ent2ment_val=e2m_ragged_val,
                ment2ent_map=tf_m2e_map,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )
        elif FLAGS.model_type == "drfact":
            # Initialize sparse tensor of ent2fact.
            with tf.device("/cpu:0"):  # Note: cpu or gpu?
                tf_e2f_data, tf_e2f_indices, tf_e2f_rowsplits = (
                    search_utils.load_ragged_matrix("ent2fact",
                                                    e2f_checkpoint))
                with tf.name_scope("RaggedConstruction_e2f"):
                    e2f_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_e2f_indices,
                        row_splits=tf_e2f_rowsplits,
                        validate=False)
                    e2f_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_e2f_data,
                        row_splits=tf_e2f_rowsplits,
                        validate=False)
            # Initialize sparse tensor of fact2ent.
            with tf.device("/cpu:0"):
                tf_f2e_data, tf_f2e_indices, tf_f2e_rowsplits = (
                    search_utils.load_ragged_matrix("fact2ent",
                                                    f2e_checkpoint))
                with tf.name_scope("RaggedConstruction_f2e"):
                    f2e_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_f2e_indices,
                        row_splits=tf_f2e_rowsplits,
                        validate=False)
                    f2e_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_f2e_data,
                        row_splits=tf_f2e_rowsplits,
                        validate=False)
            # Initialize sparse tensor of fact2fact.
            with tf.device("/cpu:0"):
                tf_f2f_data, tf_f2f_indices, tf_f2f_rowsplits = (
                    search_utils.load_ragged_matrix("fact2fact",
                                                    f2f_checkpoint))
                with tf.name_scope("RaggedConstruction_f2f"):
                    f2f_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_f2f_indices,
                        row_splits=tf_f2f_rowsplits,
                        validate=False)
                    f2f_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_f2f_data,
                        row_splits=tf_f2f_rowsplits,
                        validate=False)

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                fact_mips_config=fact_mips_config,
                is_training=is_training,
                features=features,
                ent2fact_ind=e2f_ragged_ind,
                ent2fact_val=e2f_ragged_val,
                fact2ent_ind=f2e_ragged_ind,
                fact2ent_val=f2e_ragged_val,
                fact2fact_ind=f2f_ragged_ind,
                fact2fact_val=f2f_ragged_val,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )
        else:
            # Fail fast: without this branch an unknown model_type would
            # surface later as a confusing NameError on `total_loss`.
            raise ValueError("Unsupported model_type: %s" % FLAGS.model_type)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars,
                 init_checkpoint,
                 load_only_bert=qa_config.load_only_bert)
            if use_tpu:
                # On TPU, checkpoint restoration must happen inside the
                # scaffold so it runs on the worker, not the coordinator.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Log per-GPU memory usage (in MB) before each training step;
            # the control dependency forces the prints to execute.
            one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
            devices = tf.config.experimental.list_logical_devices("GPU")
            memory_footprints = []
            for device in devices:
                memory_footprint = tf.print(
                    device.name,
                    contrib_memory_stats.MaxBytesInUse() / one_mb, " / ",
                    contrib_memory_stats.BytesLimit() / one_mb)
                memory_footprints.append(memory_footprint)

            with tf.control_dependencies(memory_footprints):
                train_op = create_optimizer(total_loss, learning_rate,
                                            num_train_steps, num_warmup_steps,
                                            use_tpu, False)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                             (mode))

        return output_spec
Exemplo n.º 2
0
  def train(self, sess):
    """Main training function/loop.

    Runs `FLAGS.num_epochs` epochs over `self.train_examples`, logging
    loss/accuracy (and a dev evaluation) every `FLAGS.log_frequency`
    batches, then saves a checkpoint if `self.ckpt_dir` is set.

    Args:
      sess: a tf session object
    """
    # For debugging/pushing limits of model: ops that print per-GPU
    # memory usage (in MB) when explicitly run.
    gpu_mb = tf.constant(1024*1024, dtype=tf.int64)
    gpus = tf.config.experimental.list_logical_devices("GPU")
    memory_footprints = []
    for gpu in gpus:
      with tf.device(gpu.name):
        memory_footprint = tf.Print(
            tf.constant(0), [
                contrib_memory_stats.BytesLimit() / gpu_mb,
                contrib_memory_stats.MaxBytesInUse() / gpu_mb
            ],
            message=gpu.name)
      memory_footprints.append(memory_footprint)

    epochs = FLAGS.num_epochs
    prints = FLAGS.log_frequency

    training_start_time = time.time()
    epochs_start_time = time.time()

    num_batches = max(int(len(self.train_examples)/self.batch_size), 1)
    tf.logging.info("Num batches per epoch: {}".format(num_batches))

    # Additional logging: per-batch loss/accuracy history.
    losses = np.zeros((epochs * num_batches))
    accuracies = np.zeros((epochs * num_batches))

    for epoch in range(epochs):
      random.shuffle(self.train_examples)
      for batch in range(num_batches):
        batch_no = epoch * num_batches + batch
        should_sample = (batch_no % prints == 0)

        train_ops_to_run = {
            "train_step": self.train_step,
            "loss": self.model.loss,
            "accuracy": self.model.accuracy,
            "accuracy_per_example": self.model.accuracy_per_ex,
            "output_relations": self.model.log_decoded_relations,
        }
        if should_sample:
          train_ops_to_run["props"] = self.model.property_loss
          train_ops_to_run["regularization"] = self.model.regularization
          for i, memory_footprint in enumerate(memory_footprints):
            train_ops_to_run["memory_footprint_{}".format(i)] = memory_footprint

        # BUG FIX: the slice previously started at `batch` (the batch
        # *index*), so consecutive batches overlapped by all but one
        # example and most of the shuffled epoch was never used. Scale
        # by batch_size to take disjoint, consecutive mini-batches.
        start = batch * self.batch_size
        batch_examples = self.train_examples[start:
                                             start + self.batch_size]
        feed_dict = self._compute_feed_dict(batch_examples)
        train_output = sess.run(train_ops_to_run, feed_dict)
        losses[batch_no] = train_output["loss"]
        accuracies[batch_no] = train_output["accuracy"]

        if should_sample:
          # Timing info: wall-clock time since the last sampled batch.
          epochs_end_time = time.time()
          epochs_time_str = str(datetime.timedelta(
              seconds=epochs_end_time - epochs_start_time))
          epochs_start_time = epochs_end_time
          precision, recall = self._evaluate_sample(sess,
                                                    train_output,
                                                    feed_dict,
                                                    batch_examples,
                                                    full_log=True)
          if precision and recall:
            pr_string = "\tPrecision: {:.3f}\tRecall {:.3f}".format(
                np.mean(precision), np.mean(recall))
          else:
            pr_string = ""
          tf.logging.info(
              ("[{}] Epoch: {}.{}\tLoss: {:.3f}|{:.3f}|{:.3f}\t" +
               "Accuracy: {:.3f}{}\n").format(
                   epochs_time_str,
                   epoch, batch,
                   train_output["loss"],
                   train_output["props"],
                   train_output["regularization"],
                   train_output["accuracy"],
                   pr_string))

          # Do a dev run, it doesn't take that long
          self.evaluate(sess, full=False)

    training_end_time = time.time()
    tf.logging.info("Training took: %s" % str(datetime.timedelta(
        seconds=training_end_time - training_start_time)))
    if self.ckpt_dir is not None:
      save_path = self.saver.save(sess,
                                  os.path.join(self.ckpt_dir, "model.ckpt"))
      tf.logging.info("Saved model at {}".format(save_path))