コード例 #1
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds the model selected by `FLAGS.model_type` ("drkit" or "drfact")
        through `create_model_fn`, optionally initializes trainable variables
        from `init_checkpoint`, and wraps the result in a `TPUEstimatorSpec`.

        Args:
          features: Mapping from feature name to a value exposing `.shape`;
            logged here and forwarded unchanged to `create_model_fn`.
          labels: Unused.
          mode: A `tf.estimator.ModeKeys` value. Only TRAIN and PREDICT are
            supported.
          params: Unused.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` carrying the loss and train op
          (TRAIN) or the predictions (PREDICT).

        Raises:
          ValueError: If `FLAGS.model_type` is neither "drkit" nor "drfact",
            or if `mode` is neither TRAIN nor PREDICT.
        """
        del labels, params  # Not used.

        def _load_ragged_pair(matrix_name, checkpoint, scope_name):
            """Loads a ragged matrix; returns its (indices, values) RaggedTensors."""
            # The matrices are explicitly pinned to the CPU device.
            with tf.device("/cpu:0"):
                data, indices, rowsplits = search_utils.load_ragged_matrix(
                    matrix_name, checkpoint)
                with tf.name_scope(scope_name):
                    # Indices and values share the same row partitioning.
                    ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=indices, row_splits=rowsplits, validate=False)
                    ragged_val = tf.RaggedTensor.from_row_splits(
                        values=data, row_splits=rowsplits, validate=False)
            return ragged_ind, ragged_val

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Entity tables are needed by both model types.
        entity_ids = search_utils.load_database(
            "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
            entity_id_checkpoint,
            dtype=tf.int32)
        entity_mask = search_utils.load_database(
            "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
            entity_mask_checkpoint)

        if FLAGS.model_type == "drkit":
            # Initialize sparse tensor of ent2ment.
            e2m_ragged_ind, e2m_ragged_val = _load_ragged_pair(
                "ent2ment", e2m_checkpoint, "RaggedConstruction_e2m")

            tf_m2e_map = search_utils.load_database("coref",
                                                    [mips_config.num_mentions],
                                                    m2e_checkpoint,
                                                    dtype=tf.int32)

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                mips_config=mips_config,
                is_training=is_training,
                features=features,
                ent2ment_ind=e2m_ragged_ind,
                ent2ment_val=e2m_ragged_val,
                ment2ent_map=tf_m2e_map,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )
        elif FLAGS.model_type == "drfact":
            # Initialize sparse tensors of ent2fact, fact2ent and fact2fact.
            e2f_ragged_ind, e2f_ragged_val = _load_ragged_pair(
                "ent2fact", e2f_checkpoint, "RaggedConstruction_e2f")
            f2e_ragged_ind, f2e_ragged_val = _load_ragged_pair(
                "fact2ent", f2e_checkpoint, "RaggedConstruction_f2e")
            f2f_ragged_ind, f2f_ragged_val = _load_ragged_pair(
                "fact2fact", f2f_checkpoint, "RaggedConstruction_f2f")

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                fact_mips_config=fact_mips_config,
                is_training=is_training,
                features=features,
                ent2fact_ind=e2f_ragged_ind,
                ent2fact_val=e2f_ragged_val,
                fact2ent_ind=f2e_ragged_ind,
                fact2ent_val=f2e_ragged_val,
                fact2fact_ind=f2f_ragged_ind,
                fact2fact_val=f2f_ragged_val,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )
        else:
            # Without this branch `total_loss`/`predictions` would be unbound
            # and the function would fail later with an opaque
            # UnboundLocalError.
            raise ValueError("Unsupported model_type: %s" % (FLAGS.model_type))

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars,
                 init_checkpoint,
                 load_only_bert=qa_config.load_only_bert)
            if use_tpu:
                # On TPU the checkpoint restore is deferred into the scaffold
                # function that is handed to the TPUEstimatorSpec.

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # One print op per logical GPU reporting max-bytes-in-use and the
            # byte limit, both divided by 1024 * 1024.
            # NOTE(review): the stats ops are not created under
            # `tf.device(device.name)`, so every line may report the same
            # device - confirm this is intended.
            one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
            devices = tf.config.experimental.list_logical_devices("GPU")
            memory_footprints = []
            for device in devices:
                memory_footprint = tf.print(
                    device.name,
                    contrib_memory_stats.MaxBytesInUse() / one_mb, " / ",
                    contrib_memory_stats.BytesLimit() / one_mb)
                memory_footprints.append(memory_footprint)

            # Ops created by the optimizer depend on the print ops above.
            with tf.control_dependencies(memory_footprints):
                train_op = create_optimizer(total_loss, learning_rate,
                                            num_train_steps, num_warmup_steps,
                                            use_tpu, False)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                             (mode))

        return output_spec
コード例 #2
0
ファイル: demo.py プロジェクト: yyht/language
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """Prediction-only `model_fn` for TPUEstimator.

        Logs the incoming features, loads the ent2ment ragged matrix together
        with the coref and entity databases, builds the model through
        `create_model_fn`, and optionally initializes trainable variables from
        `init_checkpoint`.

        Args:
          features: Mapping from feature name to a value exposing `.shape`;
            forwarded unchanged to `create_model_fn`.
          labels: Unused.
          mode: A `tf.estimator.ModeKeys` value; only PREDICT is supported.
          params: Unused.

        Returns:
          A `contrib_tpu.TPUEstimatorSpec` holding the model predictions.

        Raises:
          ValueError: If `mode` is not PREDICT. The check runs after the
            graph has been built.
        """
        tf.logging.info("*** Features ***")
        for feature_name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", feature_name,
                            features[feature_name].shape)

        training = mode == tf.estimator.ModeKeys.TRAIN

        # Entity-to-mention ragged matrix, built on the CPU device.
        with tf.device("/cpu:0"):
            e2m_values, e2m_columns, e2m_splits = (
                search_utils.load_ragged_matrix("ent2ment", e2m_checkpoint))
            with tf.name_scope("RaggedConstruction"):
                ent2ment_indices = tf.RaggedTensor.from_row_splits(
                    values=e2m_columns, row_splits=e2m_splits, validate=False)
                ent2ment_values = tf.RaggedTensor.from_row_splits(
                    values=e2m_values, row_splits=e2m_splits, validate=False)

        # Dense lookup tables restored from their own checkpoints.
        mention_to_entity = search_utils.load_database(
            "coref", [mips_config.num_mentions], m2e_checkpoint,
            dtype=tf.int32)
        entity_ids = search_utils.load_database(
            "entity_ids",
            [qa_config.num_entities, qa_config.max_entity_len],
            entity_id_checkpoint,
            dtype=tf.int32)
        entity_mask = search_utils.load_database(
            "entity_mask",
            [qa_config.num_entities, qa_config.max_entity_len],
            entity_mask_checkpoint)

        # The first element of the returned pair is discarded here.
        _unused_loss, predictions = create_model_fn(
            bert_config=bert_config,
            qa_config=qa_config,
            mips_config=mips_config,
            is_training=training,
            features=features,
            ent2ment_ind=ent2ment_indices,
            ent2ment_val=ent2ment_values,
            ment2ent_map=mention_to_entity,
            entity_ids=entity_ids,
            entity_mask=entity_mask,
            use_one_hot_embeddings=use_one_hot_embeddings,
            summary_obj=summary_obj)

        trainable_vars = tf.trainable_variables()

        scaffold_fn = None
        if init_checkpoint:
            assignment_map, _unused_names = get_assignment_map_from_checkpoint(
                trainable_vars,
                init_checkpoint,
                load_only_bert=qa_config.load_only_bert)
            if not use_tpu:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            else:
                # On TPU the restore is deferred into the scaffold function
                # passed to the estimator spec.

                def restore_in_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = restore_in_scaffold

        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT mode is supported: %s" % (mode))

        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            predictions=predictions,
                                            scaffold_fn=scaffold_fn)