Example #1
  def load_entities(self, base_dir):
    """Load entity ids and masks."""
    tf.reset_default_graph()
    id_ckpt = os.path.join(base_dir, "entity_ids")
    entity_ids = search_utils.load_database(
        "entity_ids", None, id_ckpt, dtype=tf.int32)
    mask_ckpt = os.path.join(base_dir, "entity_mask")
    entity_mask = search_utils.load_database("entity_mask", None, mask_ckpt)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.logging.info("Loading entity ids and masks...")
      np_ent_ids, np_ent_mask = sess.run([entity_ids, entity_mask])
    tf.logging.info("Building entity count matrix...")
    entity_count_matrix = search_utils.build_count_matrix(
        np_ent_ids, np_ent_mask)
    tf.logging.info("Computing IDFs...")
    self.idfs = search_utils.counts_to_idfs(entity_count_matrix, cutoff=1e-5)
    tf.logging.info("Computing entity Tf-IDFs...")
    ent_tfidfs = search_utils.counts_to_tfidf(entity_count_matrix, self.idfs)
    self.ent_tfidfs = normalize(ent_tfidfs, norm="l2", axis=0)
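The l2-normalized entity TF-IDF matrix built here can then score queries by cosine similarity, just as Example #6 below does for questions. A minimal scoring sketch (not part of the original class), assuming the same search_utils and normalize imports, the idfs/ent_tfidfs produced by load_entities, and a hypothetical list of WordPiece token ids for the query:

def score_entities(query_token_ids, idfs, ent_tfidfs):
  """Illustrative: scores one tokenized query against all entities."""
  # Term counts for the single query, built the same way as the entity counts.
  query_counts = search_utils.build_count_matrix(
      [query_token_ids], [[1] * len(query_token_ids)])
  # TF-IDF weighting with the entity-derived IDFs, l2-normalized per column.
  query_tfidf = normalize(
      search_utils.counts_to_tfidf(query_counts, idfs), norm="l2", axis=0)
  # Dot product of unit-norm columns -> 1 x num_entities cosine similarities.
  return query_tfidf.transpose().dot(ent_tfidfs)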
Example #2
def load_entity_matrices(base_dir):
  """Load entity co-occurrence and co-reference matrices."""
  cooccur_ckpt = os.path.join(base_dir, "ent2ment.npz")
  coref_ckpt = os.path.join(base_dir, "coref.npz")

  tf.reset_default_graph()
  co_data, co_indices, co_rowsplits = search_utils.load_ragged_matrix(
      "ent2ment", cooccur_ckpt)
  coref_map = search_utils.load_database(
      "coref", None, coref_ckpt, dtype=tf.int32)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.logging.info("Loading ragged matrix...")
    np_data, np_indices, np_indptr = sess.run(
        [co_data, co_indices, co_rowsplits])
    tf.logging.info("Loading coref map...")
    np_coref = sess.run(coref_map)
    num_entities = np_indptr.shape[0] - 1
    num_mentions = np_coref.shape[0]
    tf.logging.info("Creating sparse matrix %d x %d...", num_entities,
                    num_mentions)
    sp_cooccur = sp.csr_matrix((np_data, np_indices, np_indptr),
                               shape=(num_entities, num_mentions))
    tf.logging.info("Creating sparse matrix %d x %d...", num_mentions,
                    num_entities)
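    # np_coref maps each mention to its entity id, so row m of the coreference
    # matrix has a single 1 in column np_coref[m].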
    sp_coref = sp.csr_matrix((np.ones_like(np_coref, dtype=np.int32),
                              (np.arange(np_coref.shape[0]), np_coref)),
                             shape=(num_mentions, num_entities))

  metadata_file = os.path.join(base_dir, "entities.json")
  entity2id, _ = json.load(tf.gfile.Open(metadata_file))

  return sp_cooccur, sp_coref, entity2id
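A minimal sketch of how the two returned matrices compose for one hop of expansion; the index directory and seed entity key are hypothetical, and sp is scipy.sparse as in the example:

sp_cooccur, sp_coref, entity2id = load_entity_matrices("/path/to/index")
# One-hot (1 x num_entities) vector for a hypothetical seed entity.
seed = sp.csr_matrix(
    ([1.0], ([0], [entity2id["some seed entity"]])),
    shape=(1, len(entity2id)))
mentions_hit = seed.dot(sp_cooccur)       # [1, num_mentions]: entity -> mentions
neighbor_ents = mentions_hit.dot(sp_coref)  # [1, num_entities]: mentions -> entities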
Example #3
def do_combine(do_embed_mentions, do_embed_facts, embed_prefix):
    """Combines sharded DB into one single file."""
    if FLAGS.shards_to_combine is None:
        shard_range = range(FLAGS.num_shards + 1)
    else:
        shard_range = range(FLAGS.shards_to_combine)
    if do_embed_mentions:
        db_emb_str = "db_emb"
    elif do_embed_facts:
        db_emb_str = "fact_db_emb"
    with tf.device("/cpu:0"):
        all_db = []
        for i in shard_range:
            if do_embed_mentions:
                embed_str = "%s_mention_feats_%d" % (embed_prefix, i)
            elif do_embed_facts:
                embed_str = "%s_fact_feats_%d" % (embed_prefix, i)
            else:
                tf.logging.info("Error choice")
                return
            ckpt_path = os.path.join(FLAGS.multihop_output_dir, embed_str)
            if not tf.gfile.Exists(ckpt_path + ".meta"):
                tf.logging.info("%s does not exist", ckpt_path)
                continue
            reader = tf.train.NewCheckpointReader(ckpt_path)
            var_to_shape_map = reader.get_variable_to_shape_map()
            tf.logging.info("Reading %s from %s with shape %s",
                            db_emb_str + "_%d" % i, ckpt_path,
                            str(var_to_shape_map[db_emb_str + "_%d" % i]))
            tf_db = search_utils.load_database(
                db_emb_str + "_%d" % i,
                var_to_shape_map[db_emb_str + "_%d" % i], ckpt_path)
            all_db.append(tf_db)
        tf.logging.info("Reading all variables.")
        session = tf.Session()
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())
        np_db = session.run(all_db)
        tf.logging.info("Concatenating and storing.")
        np_db = np.concatenate(np_db, axis=0)
        if do_embed_mentions:
            embed_feats_str = "%s_mention_feats" % embed_prefix
        elif do_embed_facts:
            embed_feats_str = "%s_fact_feats" % embed_prefix
        search_utils.write_to_checkpoint(
            db_emb_str, np_db, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, embed_feats_str))
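A hypothetical invocation plus a read-back check of the combined checkpoint; the embed_prefix value here is illustrative, not taken from the original flags:

do_combine(do_embed_mentions=True, do_embed_facts=False, embed_prefix="bert")
combined_ckpt = os.path.join(FLAGS.multihop_output_dir, "bert_mention_feats")
reader = tf.train.NewCheckpointReader(combined_ckpt)
# The concatenated mention embeddings are stored under the variable "db_emb".
print(reader.get_variable_to_shape_map()["db_emb"])  # [total_mentions, emb_dim]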
Example #4
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator."""
        del labels, params  # Not used.
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        entity_ids = search_utils.load_database(
            "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
            entity_id_checkpoint,
            dtype=tf.int32)
        entity_mask = search_utils.load_database(
            "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
            entity_mask_checkpoint)

        if FLAGS.model_type == "drkit":
            # Initialize sparse tensor of ent2ment.
            with tf.device("/cpu:0"):
                tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
                    search_utils.load_ragged_matrix("ent2ment",
                                                    e2m_checkpoint))
                with tf.name_scope("RaggedConstruction_e2m"):
                    e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_e2m_indices,
                        row_splits=tf_e2m_rowsplits,
                        validate=False)
                    e2m_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_e2m_data,
                        row_splits=tf_e2m_rowsplits,
                        validate=False)

            tf_m2e_map = search_utils.load_database("coref",
                                                    [mips_config.num_mentions],
                                                    m2e_checkpoint,
                                                    dtype=tf.int32)

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                mips_config=mips_config,
                is_training=is_training,
                features=features,
                ent2ment_ind=e2m_ragged_ind,
                ent2ment_val=e2m_ragged_val,
                ment2ent_map=tf_m2e_map,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )
        elif FLAGS.model_type == "drfact":
            # Initialize sparse tensor of ent2fact.
            with tf.device("/cpu:0"):  # Note: cpu or gpu?
                tf_e2f_data, tf_e2f_indices, tf_e2f_rowsplits = (
                    search_utils.load_ragged_matrix("ent2fact",
                                                    e2f_checkpoint))
                with tf.name_scope("RaggedConstruction_e2f"):
                    e2f_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_e2f_indices,
                        row_splits=tf_e2f_rowsplits,
                        validate=False)
                    e2f_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_e2f_data,
                        row_splits=tf_e2f_rowsplits,
                        validate=False)
            # Initialize sparse tensor of fact2ent.
            with tf.device("/cpu:0"):
                tf_f2e_data, tf_f2e_indices, tf_f2e_rowsplits = (
                    search_utils.load_ragged_matrix("fact2ent",
                                                    f2e_checkpoint))
                with tf.name_scope("RaggedConstruction_f2e"):
                    f2e_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_f2e_indices,
                        row_splits=tf_f2e_rowsplits,
                        validate=False)
                    f2e_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_f2e_data,
                        row_splits=tf_f2e_rowsplits,
                        validate=False)
            # Initialize sparse tensor of fact2fact.
            with tf.device("/cpu:0"):
                tf_f2f_data, tf_f2f_indices, tf_f2f_rowsplits = (
                    search_utils.load_ragged_matrix("fact2fact",
                                                    f2f_checkpoint))
                with tf.name_scope("RaggedConstruction_f2f"):
                    f2f_ragged_ind = tf.RaggedTensor.from_row_splits(
                        values=tf_f2f_indices,
                        row_splits=tf_f2f_rowsplits,
                        validate=False)
                    f2f_ragged_val = tf.RaggedTensor.from_row_splits(
                        values=tf_f2f_data,
                        row_splits=tf_f2f_rowsplits,
                        validate=False)

            total_loss, predictions = create_model_fn(
                bert_config=bert_config,
                qa_config=qa_config,
                fact_mips_config=fact_mips_config,
                is_training=is_training,
                features=features,
                ent2fact_ind=e2f_ragged_ind,
                ent2fact_val=e2f_ragged_val,
                fact2ent_ind=f2e_ragged_ind,
                fact2ent_val=f2e_ragged_val,
                fact2fact_ind=f2f_ragged_ind,
                fact2fact_val=f2f_ragged_val,
                entity_ids=entity_ids,
                entity_mask=entity_mask,
                use_one_hot_embeddings=use_one_hot_embeddings,
                summary_obj=summary_obj,
                num_preds=FLAGS.num_preds,
                is_excluding=FLAGS.is_excluding,
            )

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars,
                 init_checkpoint,
                 load_only_bert=qa_config.load_only_bert)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
            devices = tf.config.experimental.list_logical_devices("GPU")
            memory_footprints = []
            for device in devices:
                memory_footprint = tf.print(
                    device.name,
                    contrib_memory_stats.MaxBytesInUse() / one_mb, " / ",
                    contrib_memory_stats.BytesLimit() / one_mb)
                memory_footprints.append(memory_footprint)

            with tf.control_dependencies(memory_footprints):
                train_op = create_optimizer(total_loss, learning_rate,
                                            num_train_steps, num_warmup_steps,
                                            use_tpu, False)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                             (mode))

        return output_spec
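The RaggedTensor construction above consumes the same (data, indices, row_splits) triple that scipy's CSR format stores, with row_splits playing the role of indptr. A small self-contained sketch of that correspondence on a toy matrix (not project data):

import numpy as np
import scipy.sparse as sp
import tensorflow as tf

dense = np.array([[0., 2., 0.],
                  [1., 0., 3.]], dtype=np.float32)
csr = sp.csr_matrix(dense)
# csr.data = [2., 1., 3.], csr.indices = [1, 0, 2], csr.indptr = [0, 1, 3]
ragged_ind = tf.RaggedTensor.from_row_splits(
    values=csr.indices, row_splits=csr.indptr)
# ragged_ind: column ids per row -> [[1], [0, 2]]
ragged_val = tf.RaggedTensor.from_row_splits(
    values=csr.data, row_splits=csr.indptr)
# ragged_val: values per row -> [[2.], [1., 3.]]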
Example #5
def main(_):
    if not tf.gfile.Exists(FLAGS.multihop_output_dir):
        tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

    # Initialize tokenizer.
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Read entities.
    if FLAGS.do_preprocess:
        tf.logging.info("Reading entities.")
        entity2id, entity2name = {}, {}
        with tf.gfile.Open(FLAGS.entity_file) as f:
            entities = json.load(f)
            tf.logging.info("Read %d entities", len(entities))
            for e, (_, n) in entities.items():
                if e.lower() in entity2id:
                    continue
                    # tf.logging.warn("%s entity repeated", e)
                entity2id[e.lower()] = len(entity2id)
                entity2name[e.lower()] = n
            tf.logging.info("Kept %d entities", len(entity2id))

    # Read paragraphs, mentions and entities.
    if FLAGS.do_preprocess:
        mentions = []
        ent_rows, ent_cols, ent_vals = [], [], []
        mention2text = {}
        total_sub_paras = [0]
        all_sub_paras = []
        num_skipped_mentions = 0.
        tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
        with tf.gfile.Open(FLAGS.wiki_file) as f:
            for ii, line in tqdm(enumerate(f)):
                if ii == FLAGS.max_total_paragraphs:
                    tf.logging.info(
                        "Processed maximum number of paragraphs, breaking.")
                    break
                if ii > 0 and ii % 100000 == 0:
                    tf.logging.info("Skipped / Kept mentions = %.3f",
                                    num_skipped_mentions / len(mentions))
                orig_para = json.loads(line.strip())
                if orig_para["kb_id"].lower() not in entity2id:
                    tf.logging.warn("%s not in entities. Skipping %s para",
                                    orig_para["kb_id"], orig_para["title"])
                    continue
                sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                               FLAGS.max_seq_length,
                                               FLAGS.doc_stride,
                                               total_sub_paras)
                for para_obj in sub_para_objs:
                    # Add mentions from this paragraph.
                    local2global = {}
                    title_entity_mention = None
                    for im, mention in enumerate(
                            para_obj["mentions"]
                        [:FLAGS.max_mentions_per_entity]):
                        if mention["kb_id"].lower() not in entity2id:
                            # tf.logging.warn("%s not in entities. Skipping mention %s",
                            #                 mention["kb_id"], mention["text"])
                            num_skipped_mentions += 1
                            continue
                        mention2text[len(mentions)] = mention["text"]
                        local2global[im] = len(mentions)
                        if mention["kb_id"] == orig_para["kb_id"]:
                            title_entity_mention = len(mentions)
                        mentions.append(
                            (entity2id[mention["kb_id"].lower()],
                             para_obj["id"], mention["start_token"],
                             mention["end_token"]))
                    for im, gm in local2global.items():
                        # entity to mention matrix.
                        ent_rows.append(entity2id[orig_para["kb_id"].lower()])
                        ent_cols.append(gm)
                        ent_vals.append(1.)
                        if title_entity_mention is not None:
                            ent_rows.append(mentions[gm][0])
                            ent_cols.append(title_entity_mention)
                            ent_vals.append(1.)
                    all_sub_paras.append(para_obj["tokens"])
                assert len(all_sub_paras) == total_sub_paras[0], (
                    len(all_sub_paras), total_sub_paras)
        tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                        total_sub_paras[0], len(mentions))
        tf.logging.info("Saving coreference map.")
        search_utils.write_to_checkpoint(
            "coref", np.array([m[0] for m in mentions], dtype=np.int32),
            tf.int32, os.path.join(FLAGS.multihop_output_dir, "coref.npz"))
        tf.logging.info("Creating entity to mentions matrix.")
        sp_entity2mention = sp.csr_matrix(
            (ent_vals, (ent_rows, ent_cols)),
            shape=[len(entity2id), len(mentions)])
        tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
        tf.logging.info("Saving as ragged tensor %s.",
                        str(sp_entity2mention.shape))
        search_utils.write_ragged_to_checkpoint(
            "ent2ment", sp_entity2mention,
            os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))
        tf.logging.info("Saving mentions metadata.")
        np.save(
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "mentions.npy"), "w"),
            np.array(mentions, dtype=np.int64))
        json.dump(
            mention2text,
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "mention2text.json"),
                "w"))
        tf.logging.info("Saving entities metadata.")
        json.dump([entity2id, entity2name],
                  tf.gfile.Open(
                      os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                      "w"))
        tf.logging.info("Saving split paragraphs.")
        json.dump(
            all_sub_paras,
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

    # Store entity tokens.
    if FLAGS.do_preprocess:
        tf.logging.info("Processing entities.")
        entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                              dtype=np.int32)
        entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                               dtype=np.float32)
        num_exceed_len = 0.
        for entity in tqdm(entity2id):
            ei = entity2id[entity]
            entity_tokens = tokenizer.tokenize(entity2name[entity])
            entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
            if len(entity_token_ids) > FLAGS.max_entity_length:
                num_exceed_len += 1
                entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
            entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
            entity_mask[ei, :len(entity_token_ids)] = 1.
        tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                        len(entity2id), num_exceed_len,
                        FLAGS.max_entity_length)
        search_utils.write_to_checkpoint(
            "entity_ids", entity_ids, tf.int32,
            os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
        search_utils.write_to_checkpoint(
            "entity_mask", entity_mask, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

    # Copy BERT checkpoint for future use.
    if FLAGS.do_preprocess:
        tf.logging.info("Copying BERT checkpoint.")
        if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir,
                                        "best_model.index")):
            bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
        else:
            bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
        tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".data-00000-of-00001",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.data-00000-of-00001"),
                      overwrite=True)
        tf.logging.info("%s.index", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".index",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.index"),
                      overwrite=True)
        tf.logging.info("%s.meta", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".meta",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.meta"),
                      overwrite=True)

    if FLAGS.do_embed:
        # Get mention embeddings from BERT.
        bert_ckpt = os.path.join(FLAGS.multihop_output_dir, "bert_init")
        if not FLAGS.do_preprocess:
            with tf.gfile.Open(
                    os.path.join(FLAGS.multihop_output_dir,
                                 "mentions.npy")) as f:
                mentions = np.load(f)
            with tf.gfile.Open(
                    os.path.join(FLAGS.multihop_output_dir,
                                 "subparas.json")) as f:
                all_sub_paras = json.load(f)
        tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                        len(mentions), len(all_sub_paras))
        shard_size = len(mentions) // FLAGS.num_shards
        bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
        if FLAGS.my_shard is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = [FLAGS.my_shard]
        for ns in shard_range:
            min_ = ns * shard_size
            max_ = (ns + 1) * shard_size
            if min_ >= len(mentions):
                break
            if max_ > len(mentions):
                max_ = len(mentions)
            min_subp = mentions[min_][1]
            max_subp = mentions[max_ - 1][1]
            tf.logging.info("Processing shard %d of %d mentions and %d paras.",
                            ns, max_ - min_, max_subp - min_subp + 1)
            para_emb = bert_predictor.get_doc_embeddings(
                all_sub_paras[min_subp:max_subp + 1])
            assert para_emb.shape[2] == 2 * FLAGS.projection_dim
            mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                                   dtype=np.float32)
            for im, mention in enumerate(mentions[min_:max_]):
                mention_emb[im, :] = np.concatenate([
                    para_emb[mention[1] - min_subp,
                             mention[2], :FLAGS.projection_dim],
                    para_emb[mention[1] - min_subp, mention[3],
                             FLAGS.projection_dim:2 * FLAGS.projection_dim]
                ])
            del para_emb
            tf.logging.info(
                "Saving %d mention features to tensorflow checkpoint.",
                mention_emb.shape[0])
            with tf.device("/cpu:0"):
                search_utils.write_to_checkpoint(
                    "db_emb_%d" % ns, mention_emb, tf.float32,
                    os.path.join(FLAGS.multihop_output_dir,
                                 "mention_feats_%d" % ns))

    if FLAGS.do_combine:
        # Combine sharded DB into one.
        if FLAGS.shards_to_combine is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = range(FLAGS.shards_to_combine)
        with tf.device("/cpu:0"):
            all_db = []
            for i in shard_range:
                ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                                         "mention_feats_%d" % i)
                reader = tf.train.NewCheckpointReader(ckpt_path)
                var_to_shape_map = reader.get_variable_to_shape_map()
                tf.logging.info("Reading %s from %s with shape %s",
                                "db_emb_%d" % i, ckpt_path,
                                str(var_to_shape_map["db_emb_%d" % i]))
                tf_db = search_utils.load_database(
                    "db_emb_%d" % i, var_to_shape_map["db_emb_%d" % i],
                    ckpt_path)
                all_db.append(tf_db)
            tf.logging.info("Reading all variables.")
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())
            np_db = session.run(all_db)
            tf.logging.info("Concatenating and storing.")
            np_db = np.concatenate(np_db, axis=0)
            search_utils.write_to_checkpoint(
                "db_emb", np_db, tf.float32,
                os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
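A minimal sketch of reading back the artifacts written by the preprocessing branch above, assuming the same imports and FLAGS as the example and that main() has already run:

with tf.gfile.Open(
        os.path.join(FLAGS.multihop_output_dir, "entities.json")) as f:
    entity2id, entity2name = json.load(f)
with tf.gfile.Open(
        os.path.join(FLAGS.multihop_output_dir, "mentions.npy")) as f:
    # Each row: (entity id, sub-paragraph id, start token, end token).
    mentions = np.load(f)
tf.logging.info("Loaded %d entities and %d mentions.",
                len(entity2id), mentions.shape[0])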
Example #6
def tfidf_linking(questions, base_dir, tokenizer, top_k, batch_size=100):
  """Match questions to entities via Tf-IDF."""
  # Load entity ids and masks.
  tf.reset_default_graph()
  id_ckpt = os.path.join(base_dir, "entity_ids")
  entity_ids = search_utils.load_database(
      "entity_ids", None, id_ckpt, dtype=tf.int32)
  mask_ckpt = os.path.join(base_dir, "entity_mask")
  entity_mask = search_utils.load_database("entity_mask", None, mask_ckpt)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.logging.info("Loading entity ids and masks...")
    np_ent_ids, np_ent_mask = sess.run([entity_ids, entity_mask])
  tf.logging.info("Building entity count matrix...")
  entity_count_matrix = search_utils.build_count_matrix(np_ent_ids, np_ent_mask)

  # Tokenize questions and build count matrix.
  tf.logging.info("Tokenizing questions...")
  ques_toks, ques_masks = [], []
  for question in questions:
    toks = tokenizer.tokenize(question["question"])
    tok_ids = tokenizer.convert_tokens_to_ids(toks)
    ques_toks.append(tok_ids)
    ques_masks.append([1 for _ in tok_ids])
  tf.logging.info("Building question count matrix...")
  question_count_matrix = search_utils.build_count_matrix(ques_toks, ques_masks)

  # Tf-IDF.
  tf.logging.info("Computing IDFs...")
  idfs = search_utils.counts_to_idfs(entity_count_matrix, cutoff=1e-5)
  tf.logging.info("Computing entity Tf-IDFs...")
  ent_tfidfs = search_utils.counts_to_tfidf(entity_count_matrix, idfs)
  ent_tfidfs = normalize(ent_tfidfs, norm="l2", axis=0)
  tf.logging.info("Computing question TF-IDFs...")
  qry_tfidfs = search_utils.counts_to_tfidf(question_count_matrix, idfs)
  qry_tfidfs = normalize(qry_tfidfs, norm="l2", axis=0)
  tf.logging.info("Searching...")
  top_doc_indices = np.empty((len(questions), top_k), dtype=np.int32)
  top_doc_distances = np.empty((len(questions), top_k), dtype=np.float32)
  # distances = qry_tfidfs.transpose().dot(ent_tfidfs)
  num_batches = len(questions) // batch_size
  tf.logging.info("Computing distances in %d batches of size %d",
                  num_batches + 1, batch_size)
  for nb in tqdm(range(num_batches + 1)):
    min_ = nb * batch_size
    max_ = (nb + 1) * batch_size
    if min_ >= len(questions):
      break
    if max_ > len(questions):
      max_ = len(questions)
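    # distances: [batch, num_entities] sparse matrix of cosine similarities
    # between this batch of questions and all entity names.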
    distances = qry_tfidfs[:, min_:max_].transpose().dot(ent_tfidfs).tocsr()
    for ii in range(min_, max_):
      my_distances = distances[ii - min_, :].tocsr()
      if len(my_distances.data) <= top_k:
        o_sort = np.argsort(-my_distances.data)
        top_doc_indices[ii, :len(o_sort)] = my_distances.indices[o_sort]
        top_doc_distances[ii, :len(o_sort)] = my_distances.data[o_sort]
        top_doc_indices[ii, len(o_sort):] = 0
        top_doc_distances[ii, len(o_sort):] = 0
      else:
        o_sort = np.argpartition(-my_distances.data, top_k)[:top_k]
        top_doc_indices[ii, :] = my_distances.indices[o_sort]
        top_doc_distances[ii, :] = my_distances.data[o_sort]

  # Load entity metadata and convert to kb_id.
  metadata_file = os.path.join(base_dir, "entities.json")
  entity2id, entity2name = json.load(tf.gfile.Open(metadata_file))
  id2entity = {i: e for e, i in entity2id.items()}
  id2name = {i: entity2name[e] for e, i in entity2id.items()}
  mentions = []
  for ii in range(len(questions)):
    my_mentions = []
    for m in range(top_k):
      my_mentions.append({
          "kb_id": id2entity[top_doc_indices[ii, m]],
          "score": str(top_doc_distances[ii, m]),
          "name": id2name[top_doc_indices[ii, m]],
      })
    mentions.append(my_mentions)

  return mentions
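A minimal usage sketch for tfidf_linking; the vocab file, index directory, and question are hypothetical, and tokenization is the BERT tokenizer module used in Example #5:

tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)
questions = [{"question": "who directed the film adaptation of the novel?"}]
linked = tfidf_linking(questions, "/path/to/index", tokenizer, top_k=5)
for match in linked[0]:
  # Each match carries the linked entity's kb_id, name and TF-IDF score.
  print(match["kb_id"], match["name"], match["score"])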
Example #7
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Initialize sparse tensors.
        with tf.device("/cpu:0"):
            tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
                search_utils.load_ragged_matrix("ent2ment", e2m_checkpoint))
            with tf.name_scope("RaggedConstruction"):
                e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
                    values=tf_e2m_indices,
                    row_splits=tf_e2m_rowsplits,
                    validate=False)
                e2m_ragged_val = tf.RaggedTensor.from_row_splits(
                    values=tf_e2m_data,
                    row_splits=tf_e2m_rowsplits,
                    validate=False)

        tf_m2e_map = search_utils.load_database("coref",
                                                [mips_config.num_mentions],
                                                m2e_checkpoint,
                                                dtype=tf.int32)
        entity_ids = search_utils.load_database(
            "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
            entity_id_checkpoint,
            dtype=tf.int32)
        entity_mask = search_utils.load_database(
            "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
            entity_mask_checkpoint)

        _, predictions = create_model_fn(
            bert_config=bert_config,
            qa_config=qa_config,
            mips_config=mips_config,
            is_training=is_training,
            features=features,
            ent2ment_ind=e2m_ragged_ind,
            ent2ment_val=e2m_ragged_val,
            ment2ent_map=tf_m2e_map,
            entity_ids=entity_ids,
            entity_mask=entity_mask,
            use_one_hot_embeddings=use_one_hot_embeddings,
            summary_obj=summary_obj)

        tvars = tf.trainable_variables()

        scaffold_fn = None
        if init_checkpoint:
            assignment_map, _ = get_assignment_map_from_checkpoint(
                tvars,
                init_checkpoint,
                load_only_bert=qa_config.load_only_bert)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        output_spec = None
        if mode == tf.estimator.ModeKeys.PREDICT:
            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       predictions=predictions,
                                                       scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only PREDICT mode is supported: %s" % (mode))

        return output_spec
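A sketch of how a predict-only model_fn like this is typically wired into a TPUEstimator; the run config, batch sizes, and predict_input_fn are hypothetical:

run_config = contrib_tpu.RunConfig(model_dir="/tmp/model_dir")
estimator = contrib_tpu.TPUEstimator(
    use_tpu=False,  # CPU/GPU fallback; set True with a TPU cluster resolver.
    model_fn=model_fn,
    config=run_config,
    train_batch_size=32,
    predict_batch_size=32)
for prediction in estimator.predict(
        input_fn=predict_input_fn, yield_single_examples=True):
    print(prediction)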