Example #1
def do_combine(do_embed_mentions, do_embed_facts, embed_prefix):
    """Combines sharded DB into one single file."""
    if FLAGS.shards_to_combine is None:
        shard_range = range(FLAGS.num_shards + 1)
    else:
        shard_range = range(FLAGS.shards_to_combine)
    if do_embed_mentions:
        db_emb_str = "db_emb"
    elif do_embed_facts:
        db_emb_str = "fact_db_emb"
    with tf.device("/cpu:0"):
        all_db = []
        for i in shard_range:
            if do_embed_mentions:
                embed_str = "%s_mention_feats_%d" % (embed_prefix, i)
            elif do_embed_facts:
                embed_str = "%s_fact_feats_%d" % (embed_prefix, i)
            else:
                tf.logging.error(
                    "Either do_embed_mentions or do_embed_facts must be set.")
                return
            ckpt_path = os.path.join(FLAGS.multihop_output_dir, embed_str)
            if not tf.gfile.Exists(ckpt_path + ".meta"):
                tf.logging.info("%s does not exist", ckpt_path)
                continue
            reader = tf.train.NewCheckpointReader(ckpt_path)
            var_to_shape_map = reader.get_variable_to_shape_map()
            tf.logging.info("Reading %s from %s with shape %s",
                            db_emb_str + "_%d" % i, ckpt_path,
                            str(var_to_shape_map[db_emb_str + "_%d" % i]))
            tf_db = search_utils.load_database(
                db_emb_str + "_%d" % i,
                var_to_shape_map[db_emb_str + "_%d" % i], ckpt_path)
            all_db.append(tf_db)
        tf.logging.info("Reading all variables.")
        session = tf.Session()
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())
        np_db = session.run(all_db)
        tf.logging.info("Concatenating and storing.")
        np_db = np.concatenate(np_db, axis=0)
        if do_embed_mentions:
            embed_feats_str = "%s_mention_feats" % embed_prefix
        elif do_embed_facts:
            embed_feats_str = "%s_fact_feats" % embed_prefix
        search_utils.write_to_checkpoint(
            db_emb_str, np_db, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, embed_feats_str))
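
The combine step above just stacks the per-shard embedding matrices row-wise; the num_shards + 1 range accounts for the extra remainder shard written by the embedding stage. A minimal NumPy sketch of the same concatenation, with purely illustrative shard sizes:

import numpy as np

# Illustrative only: three shards of 128-dim embeddings with different row counts.
shards = [np.random.rand(n, 128).astype(np.float32) for n in (1000, 1000, 537)]
combined = np.concatenate(shards, axis=0)  # rows stay in shard order
assert combined.shape == (2537, 128)
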
Example #2
def main(_):
    if not tf.gfile.Exists(FLAGS.multihop_output_dir):
        tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

    # Filenames.
    paragraphs_file = os.path.join(FLAGS.data_dir, "processed_wiki.json")
    train_file = os.path.join(FLAGS.qry_dir, "train.json")
    dev_file = os.path.join(FLAGS.qry_dir, "dev.json")
    test_file = os.path.join(FLAGS.qry_dir, "test.json")
    entities_file = os.path.join(FLAGS.data_dir, "entities.txt")

    # Initialize tokenizer.
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Read entities.
    tf.logging.info("Reading entities.")
    entity2id, entity2name = {}, {}
    with tf.gfile.Open(entities_file) as f:
        for ii, line in tqdm(enumerate(f)):
            entity = line.strip()
            entity2id[entity] = ii
            entity2name[entity] = entity

    # Read paragraphs, mentions and entities.
    mentions = []
    ent_rows, ent_cols, ent_vals = [], [], []
    para_rows, para_cols, para_vals = [], [], []
    mention2text = {}
    total_sub_paras = [0]
    all_sub_paras = []
    entity_counts = collections.defaultdict(int)
    tf.logging.info("Reading paragraphs from %s", paragraphs_file)
    with tf.gfile.Open(paragraphs_file) as f:
        for line in tqdm(f):
            orig_para = json.loads(line.strip())
            sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                           FLAGS.max_seq_length,
                                           FLAGS.doc_stride, total_sub_paras)
            for para_obj in sub_para_objs:
                # Add mentions from this paragraph.
                my_entities = []
                my_mentions = []
                for m in para_obj["mentions"]:
                    # Para to mention matrix.
                    para_rows.append(para_obj["id"])
                    para_cols.append(len(mentions))
                    para_vals.append(1.)
                    # Create entity to mention sparse connections.
                    my_entities.append(m["kb_id"])
                    my_mentions.append(len(mentions))
                    entity_counts[m["kb_id"]] += 1
                    mention2text[len(mentions)] = m["text"]
                    mentions.append((m["kb_id"], para_obj["id"],
                                     m["start_token"], m["end_token"]))
                for entity in my_entities:
                    ent_rows.append(entity)
                    ent_cols.append(para_obj["id"])
                    ent_vals.append(1. / len(my_mentions))
                all_sub_paras.append(para_obj["tokens"])
            assert len(all_sub_paras) == total_sub_paras[0], (
                len(all_sub_paras), total_sub_paras)
    tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                    total_sub_paras[0], len(mentions))
    tf.logging.info("Saving coreference map.")
    search_utils.write_to_checkpoint(
        "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "coref.npz"))
    tf.logging.info("Creating entity to mentions matrix.")
    sp_entity2para = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                   shape=[len(entity2id),
                                          len(all_sub_paras)])
    sp_entity2para_filt = preprocess_utils.filter_sparse_rows(
        sp_entity2para, FLAGS.max_paragraphs_per_entity)
    sp_para2ment = sp.csr_matrix((para_vals, (para_rows, para_cols)),
                                 shape=[len(all_sub_paras),
                                        len(mentions)])
    sp_entity2mention = sp_entity2para_filt.dot(sp_para2ment)
    tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
    tf.logging.info("Saving as ragged tensor %s.",
                    str(sp_entity2mention.shape))
    search_utils.write_ragged_to_checkpoint(
        "ent2ment", sp_entity2mention,
        os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))
    tf.logging.info("Saving mentions metadata.")
    np.save(
        tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                      "wb"), np.array(mentions, dtype=np.int64))
    json.dump(
        mention2text,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))
    tf.logging.info("Saving entities metadata.")
    json.dump([entity2id, entity2name],
              tf.gfile.Open(
                  os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                  "w"))
    json.dump(
        entity_counts,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "entity_counts.json"),
            "w"))
    tf.logging.info("Saving split paragraphs.")
    json.dump(
        all_sub_paras,
        tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "subparas.json"),
                      "w"))

    # Store entity tokens.
    tf.logging.info("Processing entities.")
    entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_len),
                          dtype=np.int32)
    entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_len),
                           dtype=np.float32)
    num_exceed_len = 0.
    for entity in tqdm(entity2id):
        ei = entity2id[entity]
        entity_tokens = tokenizer.tokenize(entity2name[entity])
        entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
        if len(entity_token_ids) > FLAGS.max_entity_len:
            num_exceed_len += 1
            entity_token_ids = entity_token_ids[:FLAGS.max_entity_len]
        entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
        entity_mask[ei, :len(entity_token_ids)] = 1.
    tf.logging.info(
        "Saving %d entity ids and mask. %d exceed max-length of %d.",
        len(entity2id), num_exceed_len, FLAGS.max_entity_len)
    search_utils.write_to_checkpoint(
        "entity_ids", entity_ids, tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
    search_utils.write_to_checkpoint(
        "entity_mask", entity_mask, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

    # Pre-process question files.
    def _preprocess_qrys(in_file, out_file):
        tf.logging.info("Working on %s", in_file)
        with tf.gfile.Open(in_file) as f_in, tf.gfile.Open(out_file,
                                                           "w") as f_out:
            for line in f_in:
                item = json.loads(line.strip())
                # Sort entities in ascending order of their frequencies.
                e_counts = [
                    entity_counts[e["kb_id"]] for e in item["entities"]
                ]
                sorted_i = sorted(enumerate(e_counts), key=lambda x: x[1])
                item["entities"] = [item["entities"][ii] for ii, _ in sorted_i]
                f_out.write(json.dumps(item) + "\n")

    _preprocess_qrys(train_file,
                     os.path.join(FLAGS.multihop_output_dir, "train.json"))
    _preprocess_qrys(dev_file,
                     os.path.join(FLAGS.multihop_output_dir, "dev.json"))
    _preprocess_qrys(test_file,
                     os.path.join(FLAGS.multihop_output_dir, "test.json"))

    # Copy BERT checkpoint for future use.
    tf.logging.info("Copying BERT checkpoint.")
    if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir, "best_model.index")):
        bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
    else:
        bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
    tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
    tf.gfile.Copy(bert_ckpt + ".data-00000-of-00001",
                  os.path.join(FLAGS.multihop_output_dir,
                               "bert_init.data-00000-of-00001"),
                  overwrite=True)
    tf.logging.info("%s.index", bert_ckpt)
    tf.gfile.Copy(bert_ckpt + ".index",
                  os.path.join(FLAGS.multihop_output_dir, "bert_init.index"),
                  overwrite=True)
    tf.logging.info("%s.meta", bert_ckpt)
    tf.gfile.Copy(bert_ckpt + ".meta",
                  os.path.join(FLAGS.multihop_output_dir, "bert_init.meta"),
                  overwrite=True)

    # Get mention embeddings from BERT.
    tf.logging.info("Computing mention embeddings for %d paras.",
                    len(all_sub_paras))
    bert_predictor = bert_utils.BERTPredictor(tokenizer, bert_ckpt)
    para_emb = bert_predictor.get_doc_embeddings(all_sub_paras)
    mention_emb = np.empty((len(mentions), 2 * bert_predictor.emb_dim),
                           dtype=np.float32)
    for im, mention in enumerate(mentions):
        mention_emb[im, :] = np.concatenate([
            para_emb[mention[1], mention[2], :], para_emb[mention[1],
                                                          mention[3], :]
        ])
    del para_emb
    tf.logging.info("Saving %d mention features to tensorflow checkpoint.",
                    mention_emb.shape[0])
    with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "db_emb", mention_emb, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
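
The entity-to-mention matrix in this example is the composition of two sparse maps, entity-to-paragraph (row-filtered by max_paragraphs_per_entity) and paragraph-to-mention. A toy scipy sketch of that composition, with hypothetical indices and purely illustrative weights:

import scipy.sparse as sp

# Toy graph: 2 entities, 3 paragraphs, 4 mentions.
# entity2para[e, p] carries a small weight when entity e appears in paragraph p.
entity2para = sp.csr_matrix(([0.5, 0.5, 1.0], ([0, 0, 1], [0, 1, 2])), shape=(2, 3))
# para2ment[p, m] is 1 if mention m occurs in paragraph p.
para2ment = sp.csr_matrix(([1., 1., 1., 1.], ([0, 0, 1, 2], [0, 1, 2, 3])), shape=(3, 4))
# Composing the two yields a weighted entity-to-mention reachability matrix.
entity2mention = entity2para.dot(para2ment)
print(entity2mention.toarray())
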
Example #3
def do_preprocess(tokenizer):
    """Loads and processes the data."""
    # Read concepts.
    tf.logging.info("Reading entities.")
    entity2id = load_concept_vocab(FLAGS.entity_file)
    entity2name = {concept: concept for (concept, _) in entity2id.items()}
    # print("print # concepts:", len(entity2id))
    tf.logging.info("# concepts: %d", len(entity2id))

    if not tf.gfile.Exists(FLAGS.multihop_output_dir):
        tf.gfile.MakeDirs(FLAGS.multihop_output_dir)
    # Read paragraphs, mentions and entities.

    mentions = []
    ent_rows, ent_cols, ent_vals = [], [], []
    ent2fact_rows, ent2fact_cols, ent2fact_vals = [], [], []
    fact2ent_rows, fact2ent_cols, fact2ent_vals = [], [], []
    ent2num_facts = collections.defaultdict(lambda: 0)
    mention2text = {}
    total_sub_paras = [0]
    all_sub_paras = []
    num_skipped_mentions = 0.

    tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
    with tf.gfile.Open(FLAGS.wiki_file) as f:
        lines = f.read().split("\n")
    if not lines[-1]:
        lines = lines[:-1]
    fact2entity = []
    for ii, line in tqdm(enumerate(lines[:]),
                         total=len(lines),
                         desc="preprocessing lines"):
        if ii == FLAGS.max_total_paragraphs:
            tf.logging.info(
                "Processed maximum number of paragraphs, breaking.")
            break
        orig_para = json.loads(line.strip())
        if orig_para["kb_id"].lower() not in entity2id:
            tf.logging.info("%s not in entities. Skipping %s para",
                            orig_para["kb_id"], orig_para["title"])
            continue
        sub_para_objs = index_util.get_sub_paras(orig_para, tokenizer,
                                                 FLAGS.max_seq_length,
                                                 FLAGS.doc_stride,
                                                 total_sub_paras)
        assert len(sub_para_objs) == 1  # each doc is a single paragraph.
        for para_obj in sub_para_objs:
            # Add mentions from this paragraph.
            local2global = {}
            # title_entity_mention = None
            assert para_obj["id"] == len(fact2entity)
            fact2entity.append([])
            for im, mention in enumerate(
                    para_obj["mentions"][:FLAGS.max_mentions_per_doc]):
                if mention["kb_id"].lower() not in entity2id:
                    # tf.logging.info("%s not in entities. Skipping mention %s",
                    #                 mention["kb_id"], mention["text"])
                    num_skipped_mentions += 1
                    continue
                mention2text[len(mentions)] = mention["text"]
                # Map the index of a local mention to the global mention id.
                local2global[im] = len(mentions)
                # if mention["kb_id"] == orig_para["kb_id"]:
                #   title_entity_mention = len(mentions)
                # The shape of 'mentions.npy' is thus #mentions * 4.
                mentions.append(
                    (entity2id[mention["kb_id"].lower()], para_obj["id"],
                     mention["start_token"], mention["end_token"]))
                # Record this entity id under the current fact (paragraph) id.
                fact2entity[para_obj["id"]].append(
                    entity2id[mention["kb_id"].lower()])
            # Entities mentioned anywhere in this paragraph.
            local_mentioned_entity_ids = set(
                mentions[gm][0] for _, gm in local2global.items())

            # Creating sparse entries for the entity2mention matrix: link each
            # entity to every co-occurring mention of a different entity.
            for _, gm in local2global.items():
                for cur_entity_id in local_mentioned_entity_ids:
                    if cur_entity_id != mentions[gm][0]:
                        ent_rows.append(cur_entity_id)
                        ent_cols.append(gm)
                        ent_vals.append(1.)

            # Creating sparse entries for entity2fact matrix.
            for cur_entity_id in local_mentioned_entity_ids:
                fact2ent_rows.append(ii)  # doc_id
                fact2ent_cols.append(cur_entity_id)
                fact2ent_vals.append(1.)
                # Note: Use tf-idf to limit this in the future.
                if ent2num_facts[cur_entity_id] >= FLAGS.max_facts_per_entity:
                    # We want to limit the number of the facts in ent2fact.
                    # Otherwise, the init_fact will be huge.
                    continue
                ent2fact_rows.append(cur_entity_id)
                ent2fact_cols.append(ii)  # doc_id
                ent2fact_vals.append(1.)
                ent2num_facts[cur_entity_id] += 1

            all_sub_paras.append(para_obj["tokens"])
        assert len(all_sub_paras) == total_sub_paras[0], (len(all_sub_paras),
                                                          total_sub_paras)

    tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                    total_sub_paras[0], len(mentions))

    tf.logging.info("Saving mention2entity coreference map.")
    search_utils.write_to_checkpoint(
        "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "coref.npz"))

    tf.logging.info("Creating ent2men matrix with %d entries.", len(ent_vals))
    # Fill a zero-inited sparse matrix for entity2mention.
    sp_entity2mention = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                      shape=[len(entity2id),
                                             len(mentions)])
    tf.logging.info("Num nonzero in e2m = %d", sp_entity2mention.getnnz())
    tf.logging.info("Saving as ragged e2m tensor %s.",
                    str(sp_entity2mention.shape))
    search_utils.write_ragged_to_checkpoint(
        "ent2ment", sp_entity2mention,
        os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))
    tf.logging.info("Saving mentions metadata.")
    np.save(
        tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                      "wb"), np.array(mentions, dtype=np.int64))
    json.dump(
        mention2text,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))
    tf.logging.info("Saving entities metadata.")

    assert len(lines) == len(all_sub_paras)
    num_facts = len(all_sub_paras)
    # Fill a zero-inited sparse matrix for entity2fact.
    sp_entity2fact = sp.csr_matrix(
        (ent2fact_vals, (ent2fact_rows, ent2fact_cols)),
        shape=[len(entity2id), num_facts])
    tf.logging.info("Num nonzero in e2f = %d", sp_entity2fact.getnnz())
    tf.logging.info("Saving as ragged e2f tensor %s.",
                    str(sp_entity2fact.shape))
    search_utils.write_ragged_to_checkpoint(
        "ent2fact", sp_entity2fact,
        os.path.join(FLAGS.multihop_output_dir,
                     "ent2fact_%d.npz" % FLAGS.max_facts_per_entity))

    # Fill a zero-inited sparse matrix for fact2entity.
    sp_fact2entity = sp.csr_matrix(
        (ent2fact_vals, (ent2fact_cols, ent2fact_rows)),  # Transpose.
        shape=[num_facts, len(entity2id)])
    tf.logging.info("Num nonzero in f2e = %d", sp_fact2entity.getnnz())
    tf.logging.info("Saving as ragged f2e tensor %s.",
                    str(sp_fact2entity.shape))
    search_utils.write_ragged_to_checkpoint(
        "fact2ent", sp_fact2entity,
        os.path.join(FLAGS.multihop_output_dir, "fact_coref.npz"))

    json.dump([entity2id, entity2name],
              tf.gfile.Open(
                  os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                  "w"))
    tf.logging.info("Saving split paragraphs.")
    json.dump(
        all_sub_paras,
        tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "subparas.json"),
                      "w"))

    # Store entity tokens.
    tf.logging.info("Processing entities.")
    entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                          dtype=np.int32)
    entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                           dtype=np.float32)
    num_exceed_len = 0.
    for entity in tqdm(entity2id):
        ei = entity2id[entity]
        entity_tokens = tokenizer.tokenize(entity2name[entity])
        entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
        if len(entity_token_ids) > FLAGS.max_entity_length:
            num_exceed_len += 1
            entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
        entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
        entity_mask[ei, :len(entity_token_ids)] = 1.
    tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                    len(entity2id), num_exceed_len, FLAGS.max_entity_length)
    search_utils.write_to_checkpoint(
        "entity_ids", entity_ids, tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
    search_utils.write_to_checkpoint(
        "entity_mask", entity_mask, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "entity_mask"))
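
load_concept_vocab is not shown above; the sketch below assumes it simply maps one lowercased concept per line of FLAGS.entity_file to a sequential id, which is enough to drive the rest of do_preprocess (the real helper may differ):

import tensorflow as tf  # TF 1.x, matching the rest of this file


def load_concept_vocab(concept_file):
    """Reads one concept per line and assigns sequential integer ids."""
    concept2id = {}
    with tf.gfile.Open(concept_file) as f:
        for line in f:
            concept = line.strip().lower()
            if concept and concept not in concept2id:
                concept2id[concept] = len(concept2id)
    return concept2id
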
Example #4
def do_embed(tokenizer, do_embed_mentions, do_embed_facts, embed_prefix):
    """Gets mention embeddings from BERT."""
    # Start Embedding.
    bert_ckpt = os.path.join(FLAGS.bert_ckpt_dir, FLAGS.ckpt_name)
    with tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                       "rb") as f:
        mentions = np.load(f)
    with tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir,
                                    "subparas.json")) as f:
        all_sub_paras = json.load(f)

    if do_embed_mentions:
        tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                        len(mentions), len(all_sub_paras))
        shard_size = len(mentions) // FLAGS.num_shards
        # Note that some FLAGS args are passed to the init function here.
        tf.logging.info("Loading BERT from %s", bert_ckpt)
        bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
        if FLAGS.my_shard is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = [FLAGS.my_shard]
        for ns in shard_range:
            min_ = ns * shard_size
            max_ = (ns + 1) * shard_size
            if min_ >= len(mentions):
                break
            if max_ > len(mentions):
                max_ = len(mentions)
            min_subp = mentions[min_][1]  # the start sentence id
            max_subp = mentions[max_ - 1][1]  # the end sentence id
            tf.logging.info("Processing shard %d of %d mentions and %d paras.",
                            ns, max_ - min_, max_subp - min_subp + 1)
            # Get the embeddings of all the sentences.
            # Note: this is always the last layer of the BERT.
            para_emb = bert_predictor.get_doc_embeddings(
                all_sub_paras[min_subp:max_subp + 1])
            assert para_emb.shape[2] == 2 * FLAGS.projection_dim
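            # Note: bert_predictor.emb_dim is expected to equal
            # FLAGS.projection_dim here, so each row filled below holds
            # exactly 2 * projection_dim values.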
            mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                                   dtype=np.float32)
            for im, mention in enumerate(mentions[min_:max_]):
                # mention[1] is the sentence id
                # mention[2/3] is the start/end index of the token
                mention_emb[im, :] = np.concatenate([
                    para_emb[mention[1] - min_subp,
                             mention[2], :FLAGS.projection_dim],
                    para_emb[mention[1] - min_subp, mention[3],
                             FLAGS.projection_dim:2 * FLAGS.projection_dim]
                ])
            del para_emb
            tf.logging.info(
                "Saving %d mention features to tensorflow checkpoint.",
                mention_emb.shape[0])
            with tf.device("/cpu:0"):
                search_utils.write_to_checkpoint(
                    "db_emb_%d" % ns, mention_emb, tf.float32,
                    os.path.join(FLAGS.multihop_output_dir,
                                 "%s_mention_feats_%d" % (embed_prefix, ns)))
    if do_embed_facts:
        tf.logging.info("Computing embeddings for %d facts with %d mentions.",
                        len(all_sub_paras), len(mentions))
        fact2mentions = collections.defaultdict(list)
        for m in mentions:
            fact2mentions[int(m[1])].append(m)
        shard_size = len(all_sub_paras) // FLAGS.num_shards
        # Note that some FLAGS args are passed to the init function here.
        bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
        if FLAGS.my_shard is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = [FLAGS.my_shard]
        for ns in shard_range:
            min_ = ns * shard_size
            max_ = (ns + 1) * shard_size
            if min_ >= len(all_sub_paras):
                break
            if max_ > len(all_sub_paras):
                max_ = len(all_sub_paras)
            min_subp = min_  # the start sentence id
            max_subp = max_ - 1  # the end sentence id

            tf.logging.info("Processing shard %d of %d facts and %d paras.",
                            ns, max_ - min_, max_subp - min_subp + 1)
            # Get the embeddings of all the sentences.
            para_emb = bert_predictor.get_doc_embeddings(
                all_sub_paras[min_subp:max_subp + 1])
            assert para_emb.shape[2] == 2 * FLAGS.projection_dim
            fact_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                                dtype=np.float32)
            for ii, _ in enumerate(all_sub_paras[min_:max_]):
                fact_id = min_ + ii
                local_mentions = fact2mentions[fact_id]

                mention_agg_emb = np.empty(
                    (len(local_mentions), 2 * bert_predictor.emb_dim),
                    dtype=np.float32)
                for jj, m in enumerate(local_mentions):
                    mention_agg_emb[jj, :] = np.concatenate([
                        para_emb[ii, m[2], :FLAGS.projection_dim],
                        para_emb[ii, m[3],
                                 FLAGS.projection_dim:2 * FLAGS.projection_dim]
                    ])

                fact_emb[ii, :] = np.mean(mention_agg_emb, axis=0)
            del para_emb
            tf.logging.info(
                "Saving %d fact features to tensorflow checkpoint.",
                fact_emb.shape[0])
            with tf.device("/cpu:0"):
                search_utils.write_to_checkpoint(
                    "fact_db_emb_%d" % ns, fact_emb, tf.float32,
                    os.path.join(FLAGS.multihop_output_dir,
                                 "%s_fact_feats_%d" % (embed_prefix, ns)))
Example #5
def main(_):
    if not tf.gfile.Exists(FLAGS.multihop_output_dir):
        tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

    # Initialize tokenizer.
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    # Read entities.
    if FLAGS.do_preprocess:
        tf.logging.info("Reading entities.")
        entity2id, entity2name = {}, {}
        with tf.gfile.Open(FLAGS.entity_file) as f:
            entities = json.load(f)
            tf.logging.info("Read %d entities", len(entities))
            for e, (_, n) in entities.items():
                if e.lower() in entity2id:
                    continue
                    # tf.logging.warn("%s entity repeated", e)
                entity2id[e.lower()] = len(entity2id)
                entity2name[e.lower()] = n
            tf.logging.info("Kept %d entities", len(entity2id))

    # Read paragraphs, mentions and entities.
    if FLAGS.do_preprocess:
        mentions = []
        ent_rows, ent_cols, ent_vals = [], [], []
        mention2text = {}
        total_sub_paras = [0]
        all_sub_paras = []
        num_skipped_mentions = 0.
        tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
        with tf.gfile.Open(FLAGS.wiki_file) as f:
            for ii, line in tqdm(enumerate(f)):
                if ii == FLAGS.max_total_paragraphs:
                    tf.logging.info(
                        "Processed maximum number of paragraphs, breaking.")
                    break
                if ii > 0 and ii % 100000 == 0:
                    tf.logging.info("Skipped / Kept mentions = %.3f",
                                    num_skipped_mentions / len(mentions))
                orig_para = json.loads(line.strip())
                if orig_para["kb_id"].lower() not in entity2id:
                    tf.logging.warn("%s not in entities. Skipping %s para",
                                    orig_para["kb_id"], orig_para["title"])
                    continue
                sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                               FLAGS.max_seq_length,
                                               FLAGS.doc_stride,
                                               total_sub_paras)
                for para_obj in sub_para_objs:
                    # Add mentions from this paragraph.
                    local2global = {}
                    title_entity_mention = None
                    for im, mention in enumerate(
                            para_obj["mentions"][:FLAGS.max_mentions_per_entity]):
                        if mention["kb_id"].lower() not in entity2id:
                            # tf.logging.warn("%s not in entities. Skipping mention %s",
                            #                 mention["kb_id"], mention["text"])
                            num_skipped_mentions += 1
                            continue
                        mention2text[len(mentions)] = mention["text"]
                        local2global[im] = len(mentions)
                        if mention["kb_id"] == orig_para["kb_id"]:
                            title_entity_mention = len(mentions)
                        mentions.append(
                            (entity2id[mention["kb_id"].lower()],
                             para_obj["id"], mention["start_token"],
                             mention["end_token"]))
                    for im, gm in local2global.items():
                        # entity to mention matrix.
                        ent_rows.append(entity2id[orig_para["kb_id"].lower()])
                        ent_cols.append(gm)
                        ent_vals.append(1.)
                        if title_entity_mention is not None:
                            ent_rows.append(mentions[gm][0])
                            ent_cols.append(title_entity_mention)
                            ent_vals.append(1.)
                    all_sub_paras.append(para_obj["tokens"])
                assert len(all_sub_paras) == total_sub_paras[0], (
                    len(all_sub_paras), total_sub_paras)
        tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                        total_sub_paras[0], len(mentions))
        tf.logging.info("Saving coreference map.")
        search_utils.write_to_checkpoint(
            "coref", np.array([m[0] for m in mentions], dtype=np.int32),
            tf.int32, os.path.join(FLAGS.multihop_output_dir, "coref.npz"))
        tf.logging.info("Creating entity to mentions matrix.")
        sp_entity2mention = sp.csr_matrix(
            (ent_vals, (ent_rows, ent_cols)),
            shape=[len(entity2id), len(mentions)])
        tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
        tf.logging.info("Saving as ragged tensor %s.",
                        str(sp_entity2mention.shape))
        search_utils.write_ragged_to_checkpoint(
            "ent2ment", sp_entity2mention,
            os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))
        tf.logging.info("Saving mentions metadata.")
        np.save(
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "mentions.npy"), "wb"),
            np.array(mentions, dtype=np.int64))
        json.dump(
            mention2text,
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "mention2text.json"),
                "w"))
        tf.logging.info("Saving entities metadata.")
        json.dump([entity2id, entity2name],
                  tf.gfile.Open(
                      os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                      "w"))
        tf.logging.info("Saving split paragraphs.")
        json.dump(
            all_sub_paras,
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

    # Store entity tokens.
    if FLAGS.do_preprocess:
        tf.logging.info("Processing entities.")
        entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                              dtype=np.int32)
        entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                               dtype=np.float32)
        num_exceed_len = 0.
        for entity in tqdm(entity2id):
            ei = entity2id[entity]
            entity_tokens = tokenizer.tokenize(entity2name[entity])
            entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
            if len(entity_token_ids) > FLAGS.max_entity_length:
                num_exceed_len += 1
                entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
            entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
            entity_mask[ei, :len(entity_token_ids)] = 1.
        tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                        len(entity2id), num_exceed_len,
                        FLAGS.max_entity_length)
        search_utils.write_to_checkpoint(
            "entity_ids", entity_ids, tf.int32,
            os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
        search_utils.write_to_checkpoint(
            "entity_mask", entity_mask, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

    # Copy BERT checkpoint for future use.
    if FLAGS.do_preprocess:
        tf.logging.info("Copying BERT checkpoint.")
        if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir,
                                        "best_model.index")):
            bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
        else:
            bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
        tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".data-00000-of-00001",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.data-00000-of-00001"),
                      overwrite=True)
        tf.logging.info("%s.index", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".index",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.index"),
                      overwrite=True)
        tf.logging.info("%s.meta", bert_ckpt)
        tf.gfile.Copy(bert_ckpt + ".meta",
                      os.path.join(FLAGS.multihop_output_dir,
                                   "bert_init.meta"),
                      overwrite=True)

    if FLAGS.do_embed:
        # Get mention embeddings from BERT.
        bert_ckpt = os.path.join(FLAGS.multihop_output_dir, "bert_init")
        if not FLAGS.do_preprocess:
            with tf.gfile.Open(
                    os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                    "rb") as f:
                mentions = np.load(f)
            with tf.gfile.Open(
                    os.path.join(FLAGS.multihop_output_dir,
                                 "subparas.json")) as f:
                all_sub_paras = json.load(f)
        tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                        len(mentions), len(all_sub_paras))
        shard_size = len(mentions) // FLAGS.num_shards
        bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
        if FLAGS.my_shard is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = [FLAGS.my_shard]
        for ns in shard_range:
            min_ = ns * shard_size
            max_ = (ns + 1) * shard_size
            if min_ >= len(mentions):
                break
            if max_ > len(mentions):
                max_ = len(mentions)
            min_subp = mentions[min_][1]
            max_subp = mentions[max_ - 1][1]
            tf.logging.info("Processing shard %d of %d mentions and %d paras.",
                            ns, max_ - min_, max_subp - min_subp + 1)
            para_emb = bert_predictor.get_doc_embeddings(
                all_sub_paras[min_subp:max_subp + 1])
            assert para_emb.shape[2] == 2 * FLAGS.projection_dim
            mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                                   dtype=np.float32)
            for im, mention in enumerate(mentions[min_:max_]):
                mention_emb[im, :] = np.concatenate([
                    para_emb[mention[1] - min_subp,
                             mention[2], :FLAGS.projection_dim],
                    para_emb[mention[1] - min_subp, mention[3],
                             FLAGS.projection_dim:2 * FLAGS.projection_dim]
                ])
            del para_emb
            tf.logging.info(
                "Saving %d mention features to tensorflow checkpoint.",
                mention_emb.shape[0])
            with tf.device("/cpu:0"):
                search_utils.write_to_checkpoint(
                    "db_emb_%d" % ns, mention_emb, tf.float32,
                    os.path.join(FLAGS.multihop_output_dir,
                                 "mention_feats_%d" % ns))

    if FLAGS.do_combine:
        # Combine sharded DB into one.
        if FLAGS.shards_to_combine is None:
            shard_range = range(FLAGS.num_shards + 1)
        else:
            shard_range = range(FLAGS.shards_to_combine)
        with tf.device("/cpu:0"):
            all_db = []
            for i in shard_range:
                ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                                         "mention_feats_%d" % i)
                reader = tf.train.NewCheckpointReader(ckpt_path)
                var_to_shape_map = reader.get_variable_to_shape_map()
                tf.logging.info("Reading %s from %s with shape %s",
                                "db_emb_%d" % i, ckpt_path,
                                str(var_to_shape_map["db_emb_%d" % i]))
                tf_db = search_utils.load_database(
                    "db_emb_%d" % i, var_to_shape_map["db_emb_%d" % i],
                    ckpt_path)
                all_db.append(tf_db)
            tf.logging.info("Reading all variables.")
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())
            np_db = session.run(all_db)
            tf.logging.info("Concatenating and storing.")
            np_db = np.concatenate(np_db, axis=0)
            search_utils.write_to_checkpoint(
                "db_emb", np_db, tf.float32,
                os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
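
All of these examples are driven by module-level flags. A hedged sketch of how the most frequently used ones might be declared with absl.flags (names come from the code above; defaults and help strings are illustrative only):

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string("multihop_output_dir", None, "Directory for preprocessed outputs.")
flags.DEFINE_string("wiki_file", None, "JSON-lines corpus of paragraphs with mentions.")
flags.DEFINE_string("vocab_file", None, "BERT vocabulary file.")
flags.DEFINE_boolean("do_lower_case", True, "Lowercase text before tokenizing.")
flags.DEFINE_integer("max_seq_length", 256, "Maximum tokens per sub-paragraph.")
flags.DEFINE_integer("doc_stride", 128, "Stride when splitting long paragraphs.")
flags.DEFINE_integer("num_shards", 1, "Number of shards for the embedding stage.")
flags.DEFINE_integer("my_shard", None, "If set, embed only this shard.")
flags.DEFINE_integer("projection_dim", None, "Per-token projection width from BERT.")
flags.DEFINE_boolean("do_preprocess", True, "Run the preprocessing stage.")
flags.DEFINE_boolean("do_embed", True, "Run the embedding stage.")
flags.DEFINE_boolean("do_combine", True, "Combine sharded embeddings into one file.")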