def do_combine(do_embed_mentions, do_embed_facts, embed_prefix):
  """Combines the sharded embedding DBs into a single checkpoint file."""
  if FLAGS.shards_to_combine is None:
    shard_range = range(FLAGS.num_shards + 1)
  else:
    shard_range = range(FLAGS.shards_to_combine)
  if do_embed_mentions:
    db_emb_str = "db_emb"
  elif do_embed_facts:
    db_emb_str = "fact_db_emb"
  with tf.device("/cpu:0"):
    all_db = []
    for i in shard_range:
      if do_embed_mentions:
        embed_str = "%s_mention_feats_%d" % (embed_prefix, i)
      elif do_embed_facts:
        embed_str = "%s_fact_feats_%d" % (embed_prefix, i)
      else:
        tf.logging.info(
            "Neither mentions nor facts were selected; nothing to combine.")
        return
      ckpt_path = os.path.join(FLAGS.multihop_output_dir, embed_str)
      if not tf.gfile.Exists(ckpt_path + ".meta"):
        tf.logging.info("%s does not exist", ckpt_path)
        continue
      reader = tf.train.NewCheckpointReader(ckpt_path)
      var_to_shape_map = reader.get_variable_to_shape_map()
      tf.logging.info("Reading %s from %s with shape %s",
                      db_emb_str + "_%d" % i, ckpt_path,
                      str(var_to_shape_map[db_emb_str + "_%d" % i]))
      tf_db = search_utils.load_database(
          db_emb_str + "_%d" % i, var_to_shape_map[db_emb_str + "_%d" % i],
          ckpt_path)
      all_db.append(tf_db)
    tf.logging.info("Reading all variables.")
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    np_db = session.run(all_db)
    tf.logging.info("Concatenating and storing.")
    np_db = np.concatenate(np_db, axis=0)
    if do_embed_mentions:
      embed_feats_str = "%s_mention_feats" % embed_prefix
    elif do_embed_facts:
      embed_feats_str = "%s_fact_feats" % embed_prefix
    search_utils.write_to_checkpoint(
        db_emb_str, np_db, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, embed_feats_str))
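
# Illustrative sketch (this helper is not part of the original script): reading
# the combined checkpoint written by do_combine() back into memory, using the
# same checkpoint-reader and search_utils calls as above. For facts, swap
# "db_emb" / "%s_mention_feats" for "fact_db_emb" / "%s_fact_feats".
def _load_combined_db(embed_prefix):
  """Loads the concatenated embedding table produced by do_combine()."""
  ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                           "%s_mention_feats" % embed_prefix)
  reader = tf.train.NewCheckpointReader(ckpt_path)
  shape = reader.get_variable_to_shape_map()["db_emb"]
  with tf.device("/cpu:0"):
    tf_db = search_utils.load_database("db_emb", shape, ckpt_path)
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    return session.run(tf_db)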
def main(_):
  if not tf.gfile.Exists(FLAGS.multihop_output_dir):
    tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

  # Filenames.
  paragraphs_file = os.path.join(FLAGS.data_dir, "processed_wiki.json")
  train_file = os.path.join(FLAGS.qry_dir, "train.json")
  dev_file = os.path.join(FLAGS.qry_dir, "dev.json")
  test_file = os.path.join(FLAGS.qry_dir, "test.json")
  entities_file = os.path.join(FLAGS.data_dir, "entities.txt")

  # Initialize tokenizer.
  tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                         do_lower_case=FLAGS.do_lower_case)

  # Read entities.
  tf.logging.info("Reading entities.")
  entity2id, entity2name = {}, {}
  with tf.gfile.Open(entities_file) as f:
    for ii, line in tqdm(enumerate(f)):
      entity = line.strip()
      entity2id[entity] = ii
      entity2name[entity] = entity

  # Read paragraphs, mentions and entities.
  mentions = []
  ent_rows, ent_cols, ent_vals = [], [], []
  para_rows, para_cols, para_vals = [], [], []
  mention2text = {}
  total_sub_paras = [0]
  all_sub_paras = []
  entity_counts = collections.defaultdict(int)
  tf.logging.info("Reading paragraphs from %s", paragraphs_file)
  with tf.gfile.Open(paragraphs_file) as f:
    for line in tqdm(f):
      orig_para = json.loads(line.strip())
      sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                     FLAGS.max_seq_length, FLAGS.doc_stride,
                                     total_sub_paras)
      for para_obj in sub_para_objs:
        # Add mentions from this paragraph.
        my_entities = []
        my_mentions = []
        for m in para_obj["mentions"]:
          # Para to mention matrix.
          para_rows.append(para_obj["id"])
          para_cols.append(len(mentions))
          para_vals.append(1.)
          # Create entity to mention sparse connections.
          my_entities.append(m["kb_id"])
          my_mentions.append(len(mentions))
          entity_counts[m["kb_id"]] += 1
          mention2text[len(mentions)] = m["text"]
          mentions.append(
              (m["kb_id"], para_obj["id"], m["start_token"], m["end_token"]))
        for entity in my_entities:
          ent_rows.append(entity)
          ent_cols.append(para_obj["id"])
          ent_vals.append(1. / len(my_mentions))
        all_sub_paras.append(para_obj["tokens"])
  assert len(all_sub_paras) == total_sub_paras[0], (len(all_sub_paras),
                                                    total_sub_paras)
  tf.logging.info("Num paragraphs = %d, Num mentions = %d", total_sub_paras[0],
                  len(mentions))

  tf.logging.info("Saving coreference map.")
  search_utils.write_to_checkpoint(
      "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
      os.path.join(FLAGS.multihop_output_dir, "coref.npz"))

  tf.logging.info("Creating entity to mentions matrix.")
  sp_entity2para = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                 shape=[len(entity2id), len(all_sub_paras)])
  sp_entity2para_filt = preprocess_utils.filter_sparse_rows(
      sp_entity2para, FLAGS.max_paragraphs_per_entity)
  sp_para2ment = sp.csr_matrix((para_vals, (para_rows, para_cols)),
                               shape=[len(all_sub_paras), len(mentions)])
  sp_entity2mention = sp_entity2para_filt.dot(sp_para2ment)
  tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
  tf.logging.info("Saving as ragged tensor %s.", str(sp_entity2mention.shape))
  search_utils.write_ragged_to_checkpoint(
      "ent2ment", sp_entity2mention,
      os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))

  tf.logging.info("Saving mentions metadata.")
  np.save(
      tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                    "w"), np.array(mentions, dtype=np.int64))
  json.dump(
      mention2text,
      tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))

  tf.logging.info("Saving entities metadata.")
  json.dump([entity2id, entity2name],
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "entities.json"), "w"))
  json.dump(
      entity_counts,
      tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "entity_counts.json"), "w"))

  tf.logging.info("Saving split paragraphs.")
  json.dump(
      all_sub_paras,
      tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

  # Store entity tokens.
  tf.logging.info("Processing entities.")
  entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_len), dtype=np.int32)
  entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_len),
                         dtype=np.float32)
  num_exceed_len = 0.
  for entity in tqdm(entity2id):
    ei = entity2id[entity]
    entity_tokens = tokenizer.tokenize(entity2name[entity])
    entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
    if len(entity_token_ids) > FLAGS.max_entity_len:
      num_exceed_len += 1
      entity_token_ids = entity_token_ids[:FLAGS.max_entity_len]
    entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
    entity_mask[ei, :len(entity_token_ids)] = 1.
  tf.logging.info("Saving %d entity ids and mask. %d exceed max-length of %d.",
                  len(entity2id), num_exceed_len, FLAGS.max_entity_len)
  search_utils.write_to_checkpoint(
      "entity_ids", entity_ids, tf.int32,
      os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
  search_utils.write_to_checkpoint(
      "entity_mask", entity_mask, tf.float32,
      os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

  # Pre-process question files.
  def _preprocess_qrys(in_file, out_file):
    tf.logging.info("Working on %s", in_file)
    with tf.gfile.Open(in_file) as f_in, tf.gfile.Open(out_file, "w") as f_out:
      for line in f_in:
        item = json.loads(line.strip())
        # Sort entities in ascending order of their frequencies.
        e_counts = [entity_counts[e["kb_id"]] for e in item["entities"]]
        sorted_i = sorted(enumerate(e_counts), key=lambda x: x[1])
        item["entities"] = [item["entities"][ii] for ii, _ in sorted_i]
        f_out.write(json.dumps(item) + "\n")

  _preprocess_qrys(train_file,
                   os.path.join(FLAGS.multihop_output_dir, "train.json"))
  _preprocess_qrys(dev_file,
                   os.path.join(FLAGS.multihop_output_dir, "dev.json"))
  _preprocess_qrys(test_file,
                   os.path.join(FLAGS.multihop_output_dir, "test.json"))

  # Copy BERT checkpoint for future use.
  tf.logging.info("Copying BERT checkpoint.")
  if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir, "best_model.index")):
    bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
  else:
    bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
  tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
  tf.gfile.Copy(
      bert_ckpt + ".data-00000-of-00001",
      os.path.join(FLAGS.multihop_output_dir,
                   "bert_init.data-00000-of-00001"),
      overwrite=True)
  tf.logging.info("%s.index", bert_ckpt)
  tf.gfile.Copy(
      bert_ckpt + ".index",
      os.path.join(FLAGS.multihop_output_dir, "bert_init.index"),
      overwrite=True)
  tf.logging.info("%s.meta", bert_ckpt)
  tf.gfile.Copy(
      bert_ckpt + ".meta",
      os.path.join(FLAGS.multihop_output_dir, "bert_init.meta"),
      overwrite=True)

  # Get mention embeddings from BERT.
  tf.logging.info("Computing mention embeddings for %d paras.",
                  len(all_sub_paras))
  bert_predictor = bert_utils.BERTPredictor(tokenizer, bert_ckpt)
  para_emb = bert_predictor.get_doc_embeddings(all_sub_paras)
  mention_emb = np.empty((len(mentions), 2 * bert_predictor.emb_dim),
                         dtype=np.float32)
  for im, mention in enumerate(mentions):
    mention_emb[im, :] = np.concatenate([
        para_emb[mention[1], mention[2], :],
        para_emb[mention[1], mention[3], :]
    ])
  del para_emb
  tf.logging.info("Saving %d mention features to tensorflow checkpoint.",
                  mention_emb.shape[0])
  with tf.device("/cpu:0"):
    search_utils.write_to_checkpoint(
        "db_emb", mention_emb, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
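
# Minimal sketch (toy sizes; this helper is not part of the pipeline) of how
# the ent2ment matrix above is assembled: scipy's COO constructor sums
# duplicate (row, col) triplets, and the entity-to-mention matrix is the
# product of the (row-filtered) entity-to-paragraph matrix and the
# paragraph-to-mention matrix.
def _example_entity2mention_product():
  """Shows the entity2para . para2ment product on a 2-entity toy example."""
  import scipy.sparse as sp  # The module-level `sp` alias is assumed above.
  # Entity 0 is mentioned twice in paragraph 0, so its two 0.5 triplets sum to 1.
  ent2para = sp.csr_matrix(([0.5, 0.5], ([0, 0], [0, 0])), shape=[2, 1])
  # Paragraph 0 contains mentions 0 and 1.
  para2ment = sp.csr_matrix(([1., 1.], ([0, 0], [0, 1])), shape=[1, 2])
  ent2ment = ent2para.dot(para2ment)
  assert ent2ment[0, 0] == 1. and ent2ment[0, 1] == 1.
  return ent2ment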
def do_preprocess(tokenizer):
  """Loads and processes the data."""
  # Read concepts.
  tf.logging.info("Reading entities.")
  entity2id = load_concept_vocab(FLAGS.entity_file)
  entity2name = {concept: concept for (concept, _) in entity2id.items()}
  tf.logging.info("# concepts: %d", len(entity2id))
  if not tf.gfile.Exists(FLAGS.multihop_output_dir):
    tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

  # Read paragraphs, mentions and entities.
  mentions = []
  ent_rows, ent_cols, ent_vals = [], [], []
  ent2fact_rows, ent2fact_cols, ent2fact_vals = [], [], []
  fact2ent_rows, fact2ent_cols, fact2ent_vals = [], [], []
  ent2num_facts = collections.defaultdict(lambda: 0)
  mention2text = {}
  total_sub_paras = [0]
  all_sub_paras = []
  num_skipped_mentions = 0.
  tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
  with tf.gfile.Open(FLAGS.wiki_file) as f:
    lines = f.read().split("\n")
    if not lines[-1]:
      lines = lines[:-1]
  fact2entity = []
  for ii, line in tqdm(
      enumerate(lines), total=len(lines), desc="preprocessing lines"):
    if ii == FLAGS.max_total_paragraphs:
      tf.logging.info("Processed maximum number of paragraphs, breaking.")
      break
    orig_para = json.loads(line.strip())
    if orig_para["kb_id"].lower() not in entity2id:
      tf.logging.info("%s not in entities. Skipping %s para",
                      orig_para["kb_id"], orig_para["title"])
      continue
    sub_para_objs = index_util.get_sub_paras(orig_para, tokenizer,
                                             FLAGS.max_seq_length,
                                             FLAGS.doc_stride, total_sub_paras)
    assert len(sub_para_objs) == 1  # Each doc is a single paragraph.
    for para_obj in sub_para_objs:
      # Add mentions from this paragraph.
      local2global = {}
      # title_entity_mention = None
      assert para_obj["id"] == len(fact2entity)
      fact2entity.append([])
      for im, mention in enumerate(
          para_obj["mentions"][:FLAGS.max_mentions_per_doc]):
        if mention["kb_id"].lower() not in entity2id:
          # tf.logging.info("%s not in entities. Skipping mention %s",
          #                 mention["kb_id"], mention["text"])
          num_skipped_mentions += 1
          continue
        mention2text[len(mentions)] = mention["text"]
        # Map the index of a local mention to the global mention id.
        local2global[im] = len(mentions)
        # if mention["kb_id"] == orig_para["kb_id"]:
        #   title_entity_mention = len(mentions)
        # The shape of 'mentions.npy' is thus #mentions * 4.
        mentions.append((entity2id[mention["kb_id"].lower()], para_obj["id"],
                         mention["start_token"], mention["end_token"]))
        # Map fact_id to entity_ids.
        fact2entity[para_obj["id"]].append(entity2id[mention["kb_id"].lower()])
      # Note: each pair of mentions in this paragraph is recorded.
      local_mentioned_entity_ids = set(
          [mentions[gm][0] for _, gm in local2global.items()])
      # Create sparse entries for the entity2mention matrix.
      for _, gm in local2global.items():
        for cur_entity_id in local_mentioned_entity_ids:
          if cur_entity_id != gm:
            ent_rows.append(cur_entity_id)
            ent_cols.append(gm)
            ent_vals.append(1.)
      # Create sparse entries for the entity2fact matrix.
      for cur_entity_id in local_mentioned_entity_ids:
        fact2ent_rows.append(ii)  # doc_id
        fact2ent_cols.append(cur_entity_id)
        fact2ent_vals.append(1.)
        # Note: use tf-idf to limit this in the future.
        if ent2num_facts[cur_entity_id] >= FLAGS.max_facts_per_entity:
          # We want to limit the number of facts in ent2fact.
          # Otherwise, the init_fact set will be huge.
          continue
        ent2fact_rows.append(cur_entity_id)
        ent2fact_cols.append(ii)  # doc_id
        ent2fact_vals.append(1.)
        ent2num_facts[cur_entity_id] += 1
      all_sub_paras.append(para_obj["tokens"])
  assert len(all_sub_paras) == total_sub_paras[0], (len(all_sub_paras),
                                                    total_sub_paras)
  tf.logging.info("Num paragraphs = %d, Num mentions = %d", total_sub_paras[0],
                  len(mentions))

  tf.logging.info("Saving mention2entity coreference map.")
  search_utils.write_to_checkpoint(
      "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
      os.path.join(FLAGS.multihop_output_dir, "coref.npz"))

  tf.logging.info("Creating ent2men matrix with %d entries.", len(ent_vals))
  # Fill a sparse matrix for entity2mention from the COO triplets.
  sp_entity2mention = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                    shape=[len(entity2id), len(mentions)])
  tf.logging.info("Num nonzero in e2m = %d", sp_entity2mention.getnnz())
  tf.logging.info("Saving as ragged e2m tensor %s.",
                  str(sp_entity2mention.shape))
  search_utils.write_ragged_to_checkpoint(
      "ent2ment", sp_entity2mention,
      os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))

  tf.logging.info("Saving mentions metadata.")
  np.save(
      tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                    "w"), np.array(mentions, dtype=np.int64))
  json.dump(
      mention2text,
      tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))

  tf.logging.info("Saving entities metadata.")
  assert len(lines) == len(all_sub_paras)
  num_facts = len(all_sub_paras)
  # Fill a sparse matrix for entity2fact from the COO triplets.
  sp_entity2fact = sp.csr_matrix(
      (ent2fact_vals, (ent2fact_rows, ent2fact_cols)),
      shape=[len(entity2id), num_facts])
  tf.logging.info("Num nonzero in e2f = %d", sp_entity2fact.getnnz())
  tf.logging.info("Saving as ragged e2f tensor %s.", str(sp_entity2fact.shape))
  search_utils.write_ragged_to_checkpoint(
      "ent2fact", sp_entity2fact,
      os.path.join(FLAGS.multihop_output_dir,
                   "ent2fact_%d.npz" % FLAGS.max_facts_per_entity))
  # Fill a sparse matrix for fact2entity (the transpose of entity2fact).
  sp_fact2entity = sp.csr_matrix(
      (ent2fact_vals, (ent2fact_cols, ent2fact_rows)),  # Transposed.
      shape=[num_facts, len(entity2id)])
  tf.logging.info("Num nonzero in f2e = %d", sp_fact2entity.getnnz())
  tf.logging.info("Saving as ragged f2e tensor %s.", str(sp_fact2entity.shape))
  search_utils.write_ragged_to_checkpoint(
      "fact2ent", sp_fact2entity,
      os.path.join(FLAGS.multihop_output_dir, "fact_coref.npz"))
  json.dump([entity2id, entity2name],
            tf.gfile.Open(
                os.path.join(FLAGS.multihop_output_dir, "entities.json"), "w"))

  tf.logging.info("Saving split paragraphs.")
  json.dump(
      all_sub_paras,
      tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

  # Store entity tokens.
  tf.logging.info("Processing entities.")
  entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                        dtype=np.int32)
  entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                         dtype=np.float32)
  num_exceed_len = 0.
  for entity in tqdm(entity2id):
    ei = entity2id[entity]
    entity_tokens = tokenizer.tokenize(entity2name[entity])
    entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
    if len(entity_token_ids) > FLAGS.max_entity_length:
      num_exceed_len += 1
      entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
    entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
    entity_mask[ei, :len(entity_token_ids)] = 1.
  tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                  len(entity2id), num_exceed_len, FLAGS.max_entity_length)
  search_utils.write_to_checkpoint(
      "entity_ids", entity_ids, tf.int32,
      os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
  search_utils.write_to_checkpoint(
      "entity_mask", entity_mask, tf.float32,
      os.path.join(FLAGS.multihop_output_dir, "entity_mask"))
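
# Minimal sketch (toy numbers; this helper is not part of the original script)
# of the per-entity fact cap applied in do_preprocess(): once an entity already
# links to FLAGS.max_facts_per_entity facts, further (entity, fact) pairs are
# dropped from the ent2fact triplets so the initial fact set per query entity
# stays small.
def _example_fact_cap(max_facts_per_entity=2):
  """Returns the (entity_id, fact_id) pairs kept under the cap."""
  import collections
  ent2num_facts = collections.defaultdict(int)
  kept = []
  for entity_id, fact_id in [(0, 10), (0, 11), (0, 12), (1, 10)]:
    if ent2num_facts[entity_id] >= max_facts_per_entity:
      continue  # Entity 0 already has 2 facts, so (0, 12) is dropped.
    kept.append((entity_id, fact_id))
    ent2num_facts[entity_id] += 1
  return kept  # [(0, 10), (0, 11), (1, 10)]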
def do_embed(tokenizer, do_embed_mentions, do_embed_facts, embed_prefix):
  """Computes mention and/or fact embeddings from BERT."""
  # Start embedding.
  bert_ckpt = os.path.join(FLAGS.bert_ckpt_dir, FLAGS.ckpt_name)
  with tf.gfile.Open(
      os.path.join(FLAGS.multihop_output_dir, "mentions.npy"), "rb") as f:
    mentions = np.load(f)
  with tf.gfile.Open(
      os.path.join(FLAGS.multihop_output_dir, "subparas.json")) as f:
    all_sub_paras = json.load(f)

  if do_embed_mentions:
    tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                    len(mentions), len(all_sub_paras))
    shard_size = len(mentions) // FLAGS.num_shards
    # Note that some FLAGS args are passed to the init function here.
    tf.logging.info("Loading BERT from %s", bert_ckpt)
    bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
    if FLAGS.my_shard is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = [FLAGS.my_shard]
    for ns in shard_range:
      min_ = ns * shard_size
      max_ = (ns + 1) * shard_size
      if min_ >= len(mentions):
        break
      if max_ > len(mentions):
        max_ = len(mentions)
      min_subp = mentions[min_][1]  # The start sentence id.
      max_subp = mentions[max_ - 1][1]  # The end sentence id.
      tf.logging.info("Processing shard %d of %d mentions and %d paras.", ns,
                      max_ - min_, max_subp - min_subp + 1)
      # Get the embeddings of all the sentences.
      # Note: this is always the last layer of BERT.
      para_emb = bert_predictor.get_doc_embeddings(
          all_sub_paras[min_subp:max_subp + 1])
      assert para_emb.shape[2] == 2 * FLAGS.projection_dim
      mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                             dtype=np.float32)
      for im, mention in enumerate(mentions[min_:max_]):
        # mention[1] is the sentence id.
        # mention[2] / mention[3] are the start / end token indices.
        mention_emb[im, :] = np.concatenate([
            para_emb[mention[1] - min_subp, mention[2], :FLAGS.projection_dim],
            para_emb[mention[1] - min_subp, mention[3],
                     FLAGS.projection_dim:2 * FLAGS.projection_dim]
        ])
      del para_emb
      tf.logging.info("Saving %d mention features to tensorflow checkpoint.",
                      mention_emb.shape[0])
      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "db_emb_%d" % ns, mention_emb, tf.float32,
            os.path.join(FLAGS.multihop_output_dir,
                         "%s_mention_feats_%d" % (embed_prefix, ns)))

  if do_embed_facts:
    tf.logging.info("Computing embeddings for %d facts with %d mentions.",
                    len(all_sub_paras), len(mentions))
    fact2mentions = collections.defaultdict(list)
    for m in mentions:
      fact2mentions[int(m[1])].append(m)
    shard_size = len(all_sub_paras) // FLAGS.num_shards
    # Note that some FLAGS args are passed to the init function here.
    bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
    if FLAGS.my_shard is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = [FLAGS.my_shard]
    for ns in shard_range:
      min_ = ns * shard_size
      max_ = (ns + 1) * shard_size
      if min_ >= len(all_sub_paras):
        break
      if max_ > len(all_sub_paras):
        max_ = len(all_sub_paras)
      min_subp = min_  # The start sentence id.
      max_subp = max_ - 1  # The end sentence id.
      tf.logging.info("Processing shard %d of %d facts and %d paras.", ns,
                      max_ - min_, max_subp - min_subp + 1)
      # Get the embeddings of all the sentences.
      para_emb = bert_predictor.get_doc_embeddings(
          all_sub_paras[min_subp:max_subp + 1])
      assert para_emb.shape[2] == 2 * FLAGS.projection_dim
      fact_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                          dtype=np.float32)
      for ii, _ in enumerate(all_sub_paras[min_:max_]):
        fact_id = min_ + ii
        local_mentions = fact2mentions[fact_id]
        mention_agg_emb = np.empty(
            (len(local_mentions), 2 * bert_predictor.emb_dim),
            dtype=np.float32)
        for jj, m in enumerate(local_mentions):
          mention_agg_emb[jj, :] = np.concatenate([
              para_emb[ii, m[2], :FLAGS.projection_dim],
              para_emb[ii, m[3], FLAGS.projection_dim:2 * FLAGS.projection_dim]
          ])
        fact_emb[ii, :] = np.mean(mention_agg_emb, axis=0)
      del para_emb
      tf.logging.info("Saving %d fact features to tensorflow checkpoint.",
                      fact_emb.shape[0])
      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "fact_db_emb_%d" % ns, fact_emb, tf.float32,
            os.path.join(FLAGS.multihop_output_dir,
                         "%s_fact_feats_%d" % (embed_prefix, ns)))
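
# Illustrative sketch (toy numbers; this helper is not part of the original
# script) of the sharding arithmetic used in do_embed(): shard_size comes from
# integer division, and the loop runs over range(num_shards + 1) so that the
# extra shard picks up the remainder; the loop breaks once min_ runs past the
# number of items.
def _example_shard_bounds(num_items=10, num_shards=3):
  """Returns the (min_, max_) index pairs produced by the sharding loop."""
  shard_size = num_items // num_shards
  bounds = []
  for ns in range(num_shards + 1):
    min_, max_ = ns * shard_size, min((ns + 1) * shard_size, num_items)
    if min_ >= num_items:
      break
    bounds.append((min_, max_))
  return bounds  # [(0, 3), (3, 6), (6, 9), (9, 10)]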
def main(_):
  if not tf.gfile.Exists(FLAGS.multihop_output_dir):
    tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

  # Initialize tokenizer.
  tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                         do_lower_case=FLAGS.do_lower_case)

  # Read entities.
  if FLAGS.do_preprocess:
    tf.logging.info("Reading entities.")
    entity2id, entity2name = {}, {}
    with tf.gfile.Open(FLAGS.entity_file) as f:
      entities = json.load(f)
      tf.logging.info("Read %d entities", len(entities))
      for e, (_, n) in entities.items():
        if e.lower() in entity2id:
          continue
        # tf.logging.warn("%s entity repeated", e)
        entity2id[e.lower()] = len(entity2id)
        entity2name[e.lower()] = n
    tf.logging.info("Kept %d entities", len(entity2id))

  # Read paragraphs, mentions and entities.
  if FLAGS.do_preprocess:
    mentions = []
    ent_rows, ent_cols, ent_vals = [], [], []
    mention2text = {}
    total_sub_paras = [0]
    all_sub_paras = []
    num_skipped_mentions = 0.
    tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
    with tf.gfile.Open(FLAGS.wiki_file) as f:
      for ii, line in tqdm(enumerate(f)):
        if ii == FLAGS.max_total_paragraphs:
          tf.logging.info("Processed maximum number of paragraphs, breaking.")
          break
        if ii > 0 and ii % 100000 == 0:
          tf.logging.info("Skipped / Kept mentions = %.3f",
                          num_skipped_mentions / len(mentions))
        orig_para = json.loads(line.strip())
        if orig_para["kb_id"].lower() not in entity2id:
          tf.logging.warn("%s not in entities. Skipping %s para",
                          orig_para["kb_id"], orig_para["title"])
          continue
        sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                       FLAGS.max_seq_length, FLAGS.doc_stride,
                                       total_sub_paras)
        for para_obj in sub_para_objs:
          # Add mentions from this paragraph.
          local2global = {}
          title_entity_mention = None
          for im, mention in enumerate(
              para_obj["mentions"][:FLAGS.max_mentions_per_entity]):
            if mention["kb_id"].lower() not in entity2id:
              # tf.logging.warn("%s not in entities. Skipping mention %s",
              #                 mention["kb_id"], mention["text"])
              num_skipped_mentions += 1
              continue
            mention2text[len(mentions)] = mention["text"]
            local2global[im] = len(mentions)
            if mention["kb_id"] == orig_para["kb_id"]:
              title_entity_mention = len(mentions)
            mentions.append(
                (entity2id[mention["kb_id"].lower()], para_obj["id"],
                 mention["start_token"], mention["end_token"]))
          for im, gm in local2global.items():
            # Entity to mention matrix.
            ent_rows.append(entity2id[orig_para["kb_id"].lower()])
            ent_cols.append(gm)
            ent_vals.append(1.)
            if title_entity_mention is not None:
              ent_rows.append(mentions[gm][0])
              ent_cols.append(title_entity_mention)
              ent_vals.append(1.)
          all_sub_paras.append(para_obj["tokens"])
    assert len(all_sub_paras) == total_sub_paras[0], (len(all_sub_paras),
                                                      total_sub_paras)
    tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                    total_sub_paras[0], len(mentions))

    tf.logging.info("Saving coreference map.")
    search_utils.write_to_checkpoint(
        "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "coref.npz"))

    tf.logging.info("Creating entity to mentions matrix.")
    sp_entity2mention = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                      shape=[len(entity2id), len(mentions)])
    tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
    tf.logging.info("Saving as ragged tensor %s.",
                    str(sp_entity2mention.shape))
    search_utils.write_ragged_to_checkpoint(
        "ent2ment", sp_entity2mention,
        os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))

    tf.logging.info("Saving mentions metadata.")
    np.save(
        tf.gfile.Open(os.path.join(FLAGS.multihop_output_dir, "mentions.npy"),
                      "w"), np.array(mentions, dtype=np.int64))
    json.dump(
        mention2text,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))

    tf.logging.info("Saving entities metadata.")
    json.dump([entity2id, entity2name],
              tf.gfile.Open(
                  os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                  "w"))

    tf.logging.info("Saving split paragraphs.")
    json.dump(
        all_sub_paras,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

  # Store entity tokens.
  if FLAGS.do_preprocess:
    tf.logging.info("Processing entities.")
    entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                          dtype=np.int32)
    entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                           dtype=np.float32)
    num_exceed_len = 0.
    for entity in tqdm(entity2id):
      ei = entity2id[entity]
      entity_tokens = tokenizer.tokenize(entity2name[entity])
      entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
      if len(entity_token_ids) > FLAGS.max_entity_length:
        num_exceed_len += 1
        entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
      entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
      entity_mask[ei, :len(entity_token_ids)] = 1.
    tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                    len(entity2id), num_exceed_len, FLAGS.max_entity_length)
    search_utils.write_to_checkpoint(
        "entity_ids", entity_ids, tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
    search_utils.write_to_checkpoint(
        "entity_mask", entity_mask, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

  # Copy BERT checkpoint for future use.
  if FLAGS.do_preprocess:
    tf.logging.info("Copying BERT checkpoint.")
    if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir, "best_model.index")):
      bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
    else:
      bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
    tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".data-00000-of-00001",
        os.path.join(FLAGS.multihop_output_dir,
                     "bert_init.data-00000-of-00001"),
        overwrite=True)
    tf.logging.info("%s.index", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".index",
        os.path.join(FLAGS.multihop_output_dir, "bert_init.index"),
        overwrite=True)
    tf.logging.info("%s.meta", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".meta",
        os.path.join(FLAGS.multihop_output_dir, "bert_init.meta"),
        overwrite=True)

  if FLAGS.do_embed:
    # Get mention embeddings from BERT.
    bert_ckpt = os.path.join(FLAGS.multihop_output_dir, "bert_init")
    if not FLAGS.do_preprocess:
      with tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "mentions.npy"), "rb") as f:
        mentions = np.load(f)
      with tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "subparas.json")) as f:
        all_sub_paras = json.load(f)

    tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                    len(mentions), len(all_sub_paras))
    shard_size = len(mentions) // FLAGS.num_shards
    bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
    if FLAGS.my_shard is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = [FLAGS.my_shard]
    for ns in shard_range:
      min_ = ns * shard_size
      max_ = (ns + 1) * shard_size
      if min_ >= len(mentions):
        break
      if max_ > len(mentions):
        max_ = len(mentions)
      min_subp = mentions[min_][1]
      max_subp = mentions[max_ - 1][1]
      tf.logging.info("Processing shard %d of %d mentions and %d paras.", ns,
                      max_ - min_, max_subp - min_subp + 1)
      para_emb = bert_predictor.get_doc_embeddings(
          all_sub_paras[min_subp:max_subp + 1])
      assert para_emb.shape[2] == 2 * FLAGS.projection_dim
      mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                             dtype=np.float32)
      for im, mention in enumerate(mentions[min_:max_]):
        mention_emb[im, :] = np.concatenate([
            para_emb[mention[1] - min_subp, mention[2], :FLAGS.projection_dim],
            para_emb[mention[1] - min_subp, mention[3],
                     FLAGS.projection_dim:2 * FLAGS.projection_dim]
        ])
      del para_emb
      tf.logging.info("Saving %d mention features to tensorflow checkpoint.",
                      mention_emb.shape[0])
      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "db_emb_%d" % ns, mention_emb, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, "mention_feats_%d" % ns))

  if FLAGS.do_combine:
    # Combine sharded DB into one.
    if FLAGS.shards_to_combine is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = range(FLAGS.shards_to_combine)
    with tf.device("/cpu:0"):
      all_db = []
      for i in shard_range:
        ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                                 "mention_feats_%d" % i)
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_to_shape_map = reader.get_variable_to_shape_map()
        tf.logging.info("Reading %s from %s with shape %s", "db_emb_%d" % i,
                        ckpt_path, str(var_to_shape_map["db_emb_%d" % i]))
        tf_db = search_utils.load_database("db_emb_%d" % i,
                                           var_to_shape_map["db_emb_%d" % i],
                                           ckpt_path)
        all_db.append(tf_db)
      tf.logging.info("Reading all variables.")
      session = tf.Session()
      session.run(tf.global_variables_initializer())
      session.run(tf.local_variables_initializer())
      np_db = session.run(all_db)
      tf.logging.info("Concatenating and storing.")
      np_db = np.concatenate(np_db, axis=0)
      search_utils.write_to_checkpoint(
          "db_emb", np_db, tf.float32,
          os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
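
# Minimal sketch (this helper is not part of the original script): before
# running main() with --do_combine, one can check that every per-shard
# checkpoint written by the --do_embed runs is present. The "mention_feats_%d"
# names match the do_embed branch above, and the ".meta" suffix mirrors the
# existence check used in do_combine().
def _missing_shards(num_shards):
  """Returns the shard indices whose mention_feats checkpoints are missing."""
  missing = []
  for ns in range(num_shards + 1):
    ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                             "mention_feats_%d" % ns)
    if not tf.gfile.Exists(ckpt_path + ".meta"):
      missing.append(ns)
  return missing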