def load_entities(self, base_dir):
  """Load entity ids and masks."""
  tf.reset_default_graph()
  id_ckpt = os.path.join(base_dir, "entity_ids")
  entity_ids = search_utils.load_database(
      "entity_ids", None, id_ckpt, dtype=tf.int32)
  mask_ckpt = os.path.join(base_dir, "entity_mask")
  entity_mask = search_utils.load_database("entity_mask", None, mask_ckpt)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.logging.info("Loading entity ids and masks...")
    np_ent_ids, np_ent_mask = sess.run([entity_ids, entity_mask])
  tf.logging.info("Building entity count matrix...")
  entity_count_matrix = search_utils.build_count_matrix(np_ent_ids, np_ent_mask)
  tf.logging.info("Computing IDFs...")
  self.idfs = search_utils.counts_to_idfs(entity_count_matrix, cutoff=1e-5)
  tf.logging.info("Computing entity Tf-IDFs...")
  ent_tfidfs = search_utils.counts_to_tfidf(entity_count_matrix, self.idfs)
  self.ent_tfidfs = normalize(ent_tfidfs, norm="l2", axis=0)
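
# --- Usage sketch (not from the original file) ---
# Once `load_entities` has populated `self.idfs` and `self.ent_tfidfs`,
# linking a query string reduces to a sparse cosine similarity against the
# entity columns. `linker` (an object that ran load_entities) and `tokenizer`
# are hypothetical stand-ins for objects built elsewhere in the pipeline.
def _demo_link_query(linker, tokenizer, query, top_k=5):
  """Scores `query` against the entity Tf-IDF index built above."""
  tok_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(query))
  qry_counts = search_utils.build_count_matrix([tok_ids], [[1] * len(tok_ids)])
  qry_tfidf = normalize(
      search_utils.counts_to_tfidf(qry_counts, linker.idfs), norm="l2", axis=0)
  # Columns are L2-normalized, so a dot product is a cosine similarity.
  scores = qry_tfidf.transpose().dot(linker.ent_tfidfs).toarray()[0]
  return np.argsort(-scores)[:top_k]  # Top-k entity ids.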
def load_entity_matrices(base_dir):
  """Load entity co-occurrence and co-reference matrices."""
  cooccur_ckpt = os.path.join(base_dir, "ent2ment.npz")
  coref_ckpt = os.path.join(base_dir, "coref.npz")
  tf.reset_default_graph()
  co_data, co_indices, co_rowsplits = search_utils.load_ragged_matrix(
      "ent2ment", cooccur_ckpt)
  coref_map = search_utils.load_database(
      "coref", None, coref_ckpt, dtype=tf.int32)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.logging.info("Loading ragged matrix...")
    np_data, np_indices, np_indptr = sess.run(
        [co_data, co_indices, co_rowsplits])
    tf.logging.info("Loading coref map...")
    np_coref = sess.run(coref_map)
  num_entities = np_indptr.shape[0] - 1
  num_mentions = np_coref.shape[0]
  tf.logging.info("Creating sparse matrix %d x %d...", num_entities,
                  num_mentions)
  sp_cooccur = sp.csr_matrix((np_data, np_indices, np_indptr),
                             shape=(num_entities, num_mentions))
  tf.logging.info("Creating sparse matrix %d x %d...", num_mentions,
                  num_entities)
  sp_coref = sp.csr_matrix(
      (np.ones_like(np_coref, dtype=np.int32),
       (np.arange(np_coref.shape[0]), np_coref)),
      shape=(num_mentions, num_entities))
  metadata_file = os.path.join(base_dir, "entities.json")
  entity2id, _ = json.load(tf.gfile.Open(metadata_file))
  return sp_cooccur, sp_coref, entity2id
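
# --- Usage sketch (not from the original file) ---
# The two returned matrices compose into a one-hop entity-to-entity walk:
# `sp_cooccur` takes entities to their co-occurring mentions, and `sp_coref`
# maps those mentions back to the entities they refer to. `seed_entity` is a
# hypothetical lowercased entity string present in `entity2id`.
def _demo_one_hop(base_dir, seed_entity):
  sp_cooccur, sp_coref, entity2id = load_entity_matrices(base_dir)
  e_vec = sp.csr_matrix(([1.], ([0], [entity2id[seed_entity]])),
                        shape=(1, len(entity2id)))
  hop = e_vec.dot(sp_cooccur).dot(sp_coref)  # 1 x num_entities.
  return hop.nonzero()[1]  # Ids of entities reachable in one hop.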
def do_combine(do_embed_mentions, do_embed_facts, embed_prefix):
  """Combines sharded DB into one single file."""
  if FLAGS.shards_to_combine is None:
    shard_range = range(FLAGS.num_shards + 1)
  else:
    shard_range = range(FLAGS.shards_to_combine)
  if do_embed_mentions:
    db_emb_str = "db_emb"
  elif do_embed_facts:
    db_emb_str = "fact_db_emb"
  with tf.device("/cpu:0"):
    all_db = []
    for i in shard_range:
      if do_embed_mentions:
        embed_str = "%s_mention_feats_%d" % (embed_prefix, i)
      elif do_embed_facts:
        embed_str = "%s_fact_feats_%d" % (embed_prefix, i)
      else:
        tf.logging.error(
            "Invalid choice: either do_embed_mentions or do_embed_facts "
            "must be set.")
        return
      ckpt_path = os.path.join(FLAGS.multihop_output_dir, embed_str)
      if not tf.gfile.Exists(ckpt_path + ".meta"):
        tf.logging.info("%s does not exist", ckpt_path)
        continue
      reader = tf.train.NewCheckpointReader(ckpt_path)
      var_to_shape_map = reader.get_variable_to_shape_map()
      tf.logging.info("Reading %s from %s with shape %s",
                      db_emb_str + "_%d" % i, ckpt_path,
                      str(var_to_shape_map[db_emb_str + "_%d" % i]))
      tf_db = search_utils.load_database(
          db_emb_str + "_%d" % i, var_to_shape_map[db_emb_str + "_%d" % i],
          ckpt_path)
      all_db.append(tf_db)
  tf.logging.info("Reading all variables.")
  session = tf.Session()
  session.run(tf.global_variables_initializer())
  session.run(tf.local_variables_initializer())
  np_db = session.run(all_db)
  tf.logging.info("Concatenating and storing.")
  np_db = np.concatenate(np_db, axis=0)
  if do_embed_mentions:
    embed_feats_str = "%s_mention_feats" % embed_prefix
  elif do_embed_facts:
    embed_feats_str = "%s_fact_feats" % embed_prefix
  search_utils.write_to_checkpoint(
      db_emb_str, np_db, tf.float32,
      os.path.join(FLAGS.multihop_output_dir, embed_feats_str))
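
# --- Usage sketch (not from the original file) ---
# A quick sanity check of the combined checkpoint written by `do_combine` for
# the mention branch. The file layout mirrors the code above; reading the
# variable name "db_emb" is only valid when do_embed_mentions=True.
def _demo_check_combined(output_dir, embed_prefix):
  ckpt_path = os.path.join(output_dir, "%s_mention_feats" % embed_prefix)
  reader = tf.train.NewCheckpointReader(ckpt_path)
  tf.logging.info("Combined db_emb shape: %s",
                  reader.get_variable_to_shape_map()["db_emb"])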
def model_fn(features, labels, mode, params):
  """The `model_fn` for TPUEstimator."""
  del labels, params  # Not used.
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s", name, features[name].shape)
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  entity_ids = search_utils.load_database(
      "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
      entity_id_checkpoint, dtype=tf.int32)
  entity_mask = search_utils.load_database(
      "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
      entity_mask_checkpoint)
  if FLAGS.model_type == "drkit":
    # Initialize sparse tensor of ent2ment.
    with tf.device("/cpu:0"):
      tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
          search_utils.load_ragged_matrix("ent2ment", e2m_checkpoint))
      with tf.name_scope("RaggedConstruction_e2m"):
        e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_e2m_indices, row_splits=tf_e2m_rowsplits, validate=False)
        e2m_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_e2m_data, row_splits=tf_e2m_rowsplits, validate=False)
    tf_m2e_map = search_utils.load_database(
        "coref", [mips_config.num_mentions], m2e_checkpoint, dtype=tf.int32)
    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        mips_config=mips_config,
        is_training=is_training,
        features=features,
        ent2ment_ind=e2m_ragged_ind,
        ent2ment_val=e2m_ragged_val,
        ment2ent_map=tf_m2e_map,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )
  elif FLAGS.model_type == "drfact":
    # Initialize sparse tensor of ent2fact.
    with tf.device("/cpu:0"):  # Note: cpu or gpu?
      tf_e2f_data, tf_e2f_indices, tf_e2f_rowsplits = (
          search_utils.load_ragged_matrix("ent2fact", e2f_checkpoint))
      with tf.name_scope("RaggedConstruction_e2f"):
        e2f_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_e2f_indices, row_splits=tf_e2f_rowsplits, validate=False)
        e2f_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_e2f_data, row_splits=tf_e2f_rowsplits, validate=False)
    # Initialize sparse tensor of fact2ent.
    with tf.device("/cpu:0"):
      tf_f2e_data, tf_f2e_indices, tf_f2e_rowsplits = (
          search_utils.load_ragged_matrix("fact2ent", f2e_checkpoint))
      with tf.name_scope("RaggedConstruction_f2e"):
        f2e_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_f2e_indices, row_splits=tf_f2e_rowsplits, validate=False)
        f2e_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_f2e_data, row_splits=tf_f2e_rowsplits, validate=False)
    # Initialize sparse tensor of fact2fact.
    with tf.device("/cpu:0"):
      tf_f2f_data, tf_f2f_indices, tf_f2f_rowsplits = (
          search_utils.load_ragged_matrix("fact2fact", f2f_checkpoint))
      with tf.name_scope("RaggedConstruction_f2f"):
        f2f_ragged_ind = tf.RaggedTensor.from_row_splits(
            values=tf_f2f_indices, row_splits=tf_f2f_rowsplits, validate=False)
        f2f_ragged_val = tf.RaggedTensor.from_row_splits(
            values=tf_f2f_data, row_splits=tf_f2f_rowsplits, validate=False)
    total_loss, predictions = create_model_fn(
        bert_config=bert_config,
        qa_config=qa_config,
        fact_mips_config=fact_mips_config,
        is_training=is_training,
        features=features,
        ent2fact_ind=e2f_ragged_ind,
        ent2fact_val=e2f_ragged_val,
        fact2ent_ind=f2e_ragged_ind,
        fact2ent_val=f2e_ragged_val,
        fact2fact_ind=f2f_ragged_ind,
        fact2fact_val=f2f_ragged_val,
        entity_ids=entity_ids,
        entity_mask=entity_mask,
        use_one_hot_embeddings=use_one_hot_embeddings,
        summary_obj=summary_obj,
        num_preds=FLAGS.num_preds,
        is_excluding=FLAGS.is_excluding,
    )
  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map,
     initialized_variable_names) = get_assignment_map_from_checkpoint(
         tvars, init_checkpoint, load_only_bert=qa_config.load_only_bert)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    one_mb = tf.constant(1024 * 1024, dtype=tf.int64)
    devices = tf.config.experimental.list_logical_devices("GPU")
    memory_footprints = []
    for device in devices:
      memory_footprint = tf.print(
          device.name, contrib_memory_stats.MaxBytesInUse() / one_mb, " / ",
          contrib_memory_stats.BytesLimit() / one_mb)
      memory_footprints.append(memory_footprint)
    with tf.control_dependencies(memory_footprints):
      train_op = create_optimizer(total_loss, learning_rate, num_train_steps,
                                  num_warmup_steps, use_tpu, False)
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.PREDICT:
    output_spec = tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % mode)
  return output_spec
def main(_):
  if not tf.gfile.Exists(FLAGS.multihop_output_dir):
    tf.gfile.MakeDirs(FLAGS.multihop_output_dir)

  # Initialize tokenizer.
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # Read entities.
  if FLAGS.do_preprocess:
    tf.logging.info("Reading entities.")
    entity2id, entity2name = {}, {}
    with tf.gfile.Open(FLAGS.entity_file) as f:
      entities = json.load(f)
      tf.logging.info("Read %d entities", len(entities))
      for e, (_, n) in entities.items():
        if e.lower() in entity2id:
          continue  # tf.logging.warn("%s entity repeated", e)
        entity2id[e.lower()] = len(entity2id)
        entity2name[e.lower()] = n
    tf.logging.info("Kept %d entities", len(entity2id))

  # Read paragraphs, mentions and entities.
  if FLAGS.do_preprocess:
    mentions = []
    ent_rows, ent_cols, ent_vals = [], [], []
    mention2text = {}
    total_sub_paras = [0]
    all_sub_paras = []
    num_skipped_mentions = 0.
    tf.logging.info("Reading paragraphs from %s", FLAGS.wiki_file)
    with tf.gfile.Open(FLAGS.wiki_file) as f:
      for ii, line in tqdm(enumerate(f)):
        if ii == FLAGS.max_total_paragraphs:
          tf.logging.info("Processed maximum number of paragraphs, breaking.")
          break
        if ii > 0 and ii % 100000 == 0:
          tf.logging.info("Skipped / Kept mentions = %.3f",
                          num_skipped_mentions / len(mentions))
        orig_para = json.loads(line.strip())
        if orig_para["kb_id"].lower() not in entity2id:
          tf.logging.warn("%s not in entities. Skipping %s para",
                          orig_para["kb_id"], orig_para["title"])
          continue
        sub_para_objs = _get_sub_paras(orig_para, tokenizer,
                                       FLAGS.max_seq_length, FLAGS.doc_stride,
                                       total_sub_paras)
        for para_obj in sub_para_objs:
          # Add mentions from this paragraph.
          local2global = {}
          title_entity_mention = None
          for im, mention in enumerate(
              para_obj["mentions"][:FLAGS.max_mentions_per_entity]):
            if mention["kb_id"].lower() not in entity2id:
              # tf.logging.warn("%s not in entities. Skipping mention %s",
              #                 mention["kb_id"], mention["text"])
              num_skipped_mentions += 1
              continue
            mention2text[len(mentions)] = mention["text"]
            local2global[im] = len(mentions)
            if mention["kb_id"] == orig_para["kb_id"]:
              title_entity_mention = len(mentions)
            mentions.append((entity2id[mention["kb_id"].lower()],
                             para_obj["id"], mention["start_token"],
                             mention["end_token"]))
          for im, gm in local2global.items():
            # Entity to mention matrix.
            ent_rows.append(entity2id[orig_para["kb_id"].lower()])
            ent_cols.append(gm)
            ent_vals.append(1.)
            if title_entity_mention is not None:
              ent_rows.append(mentions[gm][0])
              ent_cols.append(title_entity_mention)
              ent_vals.append(1.)
          all_sub_paras.append(para_obj["tokens"])
        assert len(all_sub_paras) == total_sub_paras[0], (len(all_sub_paras),
                                                          total_sub_paras)
    tf.logging.info("Num paragraphs = %d, Num mentions = %d",
                    total_sub_paras[0], len(mentions))

    tf.logging.info("Saving coreference map.")
    search_utils.write_to_checkpoint(
        "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "coref.npz"))
    tf.logging.info("Creating entity to mentions matrix.")
    sp_entity2mention = sp.csr_matrix((ent_vals, (ent_rows, ent_cols)),
                                      shape=[len(entity2id), len(mentions)])
    tf.logging.info("Num nonzero = %d", sp_entity2mention.getnnz())
    tf.logging.info("Saving as ragged tensor %s.", str(sp_entity2mention.shape))
    search_utils.write_ragged_to_checkpoint(
        "ent2ment", sp_entity2mention,
        os.path.join(FLAGS.multihop_output_dir, "ent2ment.npz"))
    tf.logging.info("Saving mentions metadata.")
    np.save(
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "mentions.npy"), "w"),
        np.array(mentions, dtype=np.int64))
    json.dump(
        mention2text,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "mention2text.json"), "w"))
    tf.logging.info("Saving entities metadata.")
    json.dump([entity2id, entity2name],
              tf.gfile.Open(
                  os.path.join(FLAGS.multihop_output_dir, "entities.json"),
                  "w"))
    tf.logging.info("Saving split paragraphs.")
    json.dump(
        all_sub_paras,
        tf.gfile.Open(
            os.path.join(FLAGS.multihop_output_dir, "subparas.json"), "w"))

  # Store entity tokens.
  if FLAGS.do_preprocess:
    tf.logging.info("Processing entities.")
    entity_ids = np.zeros((len(entity2id), FLAGS.max_entity_length),
                          dtype=np.int32)
    entity_mask = np.zeros((len(entity2id), FLAGS.max_entity_length),
                           dtype=np.float32)
    num_exceed_len = 0.
    for entity in tqdm(entity2id):
      ei = entity2id[entity]
      entity_tokens = tokenizer.tokenize(entity2name[entity])
      entity_token_ids = tokenizer.convert_tokens_to_ids(entity_tokens)
      if len(entity_token_ids) > FLAGS.max_entity_length:
        num_exceed_len += 1
        entity_token_ids = entity_token_ids[:FLAGS.max_entity_length]
      entity_ids[ei, :len(entity_token_ids)] = entity_token_ids
      entity_mask[ei, :len(entity_token_ids)] = 1.
    tf.logging.info("Saving %d entity ids. %d exceed max-length of %d.",
                    len(entity2id), num_exceed_len, FLAGS.max_entity_length)
    search_utils.write_to_checkpoint(
        "entity_ids", entity_ids, tf.int32,
        os.path.join(FLAGS.multihop_output_dir, "entity_ids"))
    search_utils.write_to_checkpoint(
        "entity_mask", entity_mask, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "entity_mask"))

  # Copy BERT checkpoint for future use.
  if FLAGS.do_preprocess:
    tf.logging.info("Copying BERT checkpoint.")
    if tf.gfile.Exists(os.path.join(FLAGS.pretrain_dir, "best_model.index")):
      bert_ckpt = os.path.join(FLAGS.pretrain_dir, "best_model")
    else:
      bert_ckpt = tf.train.latest_checkpoint(FLAGS.pretrain_dir)
    tf.logging.info("%s.data-00000-of-00001", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".data-00000-of-00001",
        os.path.join(FLAGS.multihop_output_dir,
                     "bert_init.data-00000-of-00001"),
        overwrite=True)
    tf.logging.info("%s.index", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".index",
        os.path.join(FLAGS.multihop_output_dir, "bert_init.index"),
        overwrite=True)
    tf.logging.info("%s.meta", bert_ckpt)
    tf.gfile.Copy(
        bert_ckpt + ".meta",
        os.path.join(FLAGS.multihop_output_dir, "bert_init.meta"),
        overwrite=True)

  if FLAGS.do_embed:
    # Get mention embeddings from BERT.
    bert_ckpt = os.path.join(FLAGS.multihop_output_dir, "bert_init")
    if not FLAGS.do_preprocess:
      with tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "mentions.npy")) as f:
        mentions = np.load(f)
      with tf.gfile.Open(
          os.path.join(FLAGS.multihop_output_dir, "subparas.json")) as f:
        all_sub_paras = json.load(f)
    tf.logging.info("Computing embeddings for %d mentions over %d paras.",
                    len(mentions), len(all_sub_paras))
    shard_size = len(mentions) // FLAGS.num_shards
    bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)
    if FLAGS.my_shard is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = [FLAGS.my_shard]
    for ns in shard_range:
      min_ = ns * shard_size
      max_ = (ns + 1) * shard_size
      if min_ >= len(mentions):
        break
      if max_ > len(mentions):
        max_ = len(mentions)
      min_subp = mentions[min_][1]
      max_subp = mentions[max_ - 1][1]
      tf.logging.info("Processing shard %d of %d mentions and %d paras.", ns,
                      max_ - min_, max_subp - min_subp + 1)
      para_emb = bert_predictor.get_doc_embeddings(
          all_sub_paras[min_subp:max_subp + 1])
      assert para_emb.shape[2] == 2 * FLAGS.projection_dim
      mention_emb = np.empty((max_ - min_, 2 * bert_predictor.emb_dim),
                             dtype=np.float32)
      for im, mention in enumerate(mentions[min_:max_]):
        mention_emb[im, :] = np.concatenate([
            para_emb[mention[1] - min_subp, mention[2], :FLAGS.projection_dim],
            para_emb[mention[1] - min_subp, mention[3],
                     FLAGS.projection_dim:2 * FLAGS.projection_dim]
        ])
      del para_emb
      tf.logging.info("Saving %d mention features to tensorflow checkpoint.",
                      mention_emb.shape[0])
      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "db_emb_%d" % ns, mention_emb, tf.float32,
            os.path.join(FLAGS.multihop_output_dir, "mention_feats_%d" % ns))

  if FLAGS.do_combine:
    # Combine sharded DB into one.
    if FLAGS.shards_to_combine is None:
      shard_range = range(FLAGS.num_shards + 1)
    else:
      shard_range = range(FLAGS.shards_to_combine)
    with tf.device("/cpu:0"):
      all_db = []
      for i in shard_range:
        ckpt_path = os.path.join(FLAGS.multihop_output_dir,
                                 "mention_feats_%d" % i)
        # Fixed: use tf.train.NewCheckpointReader, consistent with do_combine
        # above (tf.NewCheckpointReader is not a public TF1 API).
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_to_shape_map = reader.get_variable_to_shape_map()
        tf.logging.info("Reading %s from %s with shape %s", "db_emb_%d" % i,
                        ckpt_path, str(var_to_shape_map["db_emb_%d" % i]))
        tf_db = search_utils.load_database(
            "db_emb_%d" % i, var_to_shape_map["db_emb_%d" % i], ckpt_path)
        all_db.append(tf_db)
    tf.logging.info("Reading all variables.")
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    session.run(tf.local_variables_initializer())
    np_db = session.run(all_db)
    tf.logging.info("Concatenating and storing.")
    np_db = np.concatenate(np_db, axis=0)
    search_utils.write_to_checkpoint(
        "db_emb", np_db, tf.float32,
        os.path.join(FLAGS.multihop_output_dir, "mention_feats"))
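
# --- Illustrative sketch (not from the original file) ---
# The embedding step above stores each mention as the concatenation of the
# projected start-token and end-token vectors of its sub-paragraph, so each
# row of the DB is 2 * FLAGS.projection_dim wide. This helper restates that
# slicing for a single mention tuple (entity_id, para_id, start, end).
def _demo_mention_slice(para_emb, mention, min_subp, projection_dim):
  start_vec = para_emb[mention[1] - min_subp, mention[2], :projection_dim]
  end_vec = para_emb[mention[1] - min_subp, mention[3],
                     projection_dim:2 * projection_dim]
  return np.concatenate([start_vec, end_vec])  # Shape: (2 * projection_dim,).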
def tfidf_linking(questions, base_dir, tokenizer, top_k, batch_size=100):
  """Match questions to entities via Tf-IDF."""
  # Load entity ids and masks.
  tf.reset_default_graph()
  id_ckpt = os.path.join(base_dir, "entity_ids")
  entity_ids = search_utils.load_database(
      "entity_ids", None, id_ckpt, dtype=tf.int32)
  mask_ckpt = os.path.join(base_dir, "entity_mask")
  entity_mask = search_utils.load_database("entity_mask", None, mask_ckpt)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.logging.info("Loading entity ids and masks...")
    np_ent_ids, np_ent_mask = sess.run([entity_ids, entity_mask])
  tf.logging.info("Building entity count matrix...")
  entity_count_matrix = search_utils.build_count_matrix(np_ent_ids, np_ent_mask)

  # Tokenize questions and build count matrix.
  tf.logging.info("Tokenizing questions...")
  ques_toks, ques_masks = [], []
  for question in questions:
    toks = tokenizer.tokenize(question["question"])
    tok_ids = tokenizer.convert_tokens_to_ids(toks)
    ques_toks.append(tok_ids)
    ques_masks.append([1 for _ in tok_ids])
  tf.logging.info("Building question count matrix...")
  question_count_matrix = search_utils.build_count_matrix(ques_toks, ques_masks)

  # Tf-IDF.
  tf.logging.info("Computing IDFs...")
  idfs = search_utils.counts_to_idfs(entity_count_matrix, cutoff=1e-5)
  tf.logging.info("Computing entity Tf-IDFs...")
  ent_tfidfs = search_utils.counts_to_tfidf(entity_count_matrix, idfs)
  ent_tfidfs = normalize(ent_tfidfs, norm="l2", axis=0)
  tf.logging.info("Computing question Tf-IDFs...")
  qry_tfidfs = search_utils.counts_to_tfidf(question_count_matrix, idfs)
  qry_tfidfs = normalize(qry_tfidfs, norm="l2", axis=0)
  tf.logging.info("Searching...")
  top_doc_indices = np.empty((len(questions), top_k), dtype=np.int32)
  top_doc_distances = np.empty((len(questions), top_k), dtype=np.float32)
  # distances = qry_tfidfs.transpose().dot(ent_tfidfs)
  num_batches = len(questions) // batch_size
  tf.logging.info("Computing distances in %d batches of size %d",
                  num_batches + 1, batch_size)
  for nb in tqdm(range(num_batches + 1)):
    min_ = nb * batch_size
    max_ = (nb + 1) * batch_size
    if min_ >= len(questions):
      break
    if max_ > len(questions):
      max_ = len(questions)
    distances = qry_tfidfs[:, min_:max_].transpose().dot(ent_tfidfs).tocsr()
    for ii in range(min_, max_):
      my_distances = distances[ii - min_, :].tocsr()
      if len(my_distances.data) <= top_k:
        o_sort = np.argsort(-my_distances.data)
        top_doc_indices[ii, :len(o_sort)] = my_distances.indices[o_sort]
        top_doc_distances[ii, :len(o_sort)] = my_distances.data[o_sort]
        top_doc_indices[ii, len(o_sort):] = 0
        top_doc_distances[ii, len(o_sort):] = 0
      else:
        o_sort = np.argpartition(-my_distances.data, top_k)[:top_k]
        top_doc_indices[ii, :] = my_distances.indices[o_sort]
        top_doc_distances[ii, :] = my_distances.data[o_sort]

  # Load entity metadata and convert to kb_id.
  metadata_file = os.path.join(base_dir, "entities.json")
  entity2id, entity2name = json.load(tf.gfile.Open(metadata_file))
  id2entity = {i: e for e, i in entity2id.items()}
  id2name = {i: entity2name[e] for e, i in entity2id.items()}
  mentions = []
  for ii in range(len(questions)):
    my_mentions = []
    for m in range(top_k):
      my_mentions.append({
          "kb_id": id2entity[top_doc_indices[ii, m]],
          "score": str(top_doc_distances[ii, m]),
          "name": id2name[top_doc_indices[ii, m]],
      })
    mentions.append(my_mentions)
  return mentions
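
# --- Usage sketch (not from the original file) ---
# Calling `tfidf_linking` end to end. The vocab path and the question text
# are illustrative; `questions` must be dicts with a "question" key, matching
# the access pattern above.
def _demo_tfidf_linking(base_dir, vocab_file):
  tokenizer = tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=True)
  questions = [{"question": "who directed the film about the titanic"}]
  linked = tfidf_linking(questions, base_dir, tokenizer, top_k=3)
  for m in linked[0]:
    tf.logging.info("%s (%s): %s", m["name"], m["kb_id"], m["score"])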
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator."""
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info("  name = %s, shape = %s", name, features[name].shape)
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  # Initialize sparse tensors.
  with tf.device("/cpu:0"):
    tf_e2m_data, tf_e2m_indices, tf_e2m_rowsplits = (
        search_utils.load_ragged_matrix("ent2ment", e2m_checkpoint))
    with tf.name_scope("RaggedConstruction"):
      e2m_ragged_ind = tf.RaggedTensor.from_row_splits(
          values=tf_e2m_indices, row_splits=tf_e2m_rowsplits, validate=False)
      e2m_ragged_val = tf.RaggedTensor.from_row_splits(
          values=tf_e2m_data, row_splits=tf_e2m_rowsplits, validate=False)
  tf_m2e_map = search_utils.load_database(
      "coref", [mips_config.num_mentions], m2e_checkpoint, dtype=tf.int32)
  entity_ids = search_utils.load_database(
      "entity_ids", [qa_config.num_entities, qa_config.max_entity_len],
      entity_id_checkpoint, dtype=tf.int32)
  entity_mask = search_utils.load_database(
      "entity_mask", [qa_config.num_entities, qa_config.max_entity_len],
      entity_mask_checkpoint)

  _, predictions = create_model_fn(
      bert_config=bert_config,
      qa_config=qa_config,
      mips_config=mips_config,
      is_training=is_training,
      features=features,
      ent2ment_ind=e2m_ragged_ind,
      ent2ment_val=e2m_ragged_val,
      ment2ent_map=tf_m2e_map,
      entity_ids=entity_ids,
      entity_mask=entity_mask,
      use_one_hot_embeddings=use_one_hot_embeddings,
      summary_obj=summary_obj)
  tvars = tf.trainable_variables()
  scaffold_fn = None
  if init_checkpoint:
    assignment_map, _ = get_assignment_map_from_checkpoint(
        tvars, init_checkpoint, load_only_bert=qa_config.load_only_bert)
    if use_tpu:

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  output_spec = None
  if mode == tf.estimator.ModeKeys.PREDICT:
    output_spec = contrib_tpu.TPUEstimatorSpec(
        mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
  else:
    raise ValueError("Only PREDICT mode is supported: %s" % mode)
  return output_spec
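
# --- Usage sketch (not from the original file) ---
# This `model_fn` only supports PREDICT, so it is meant to be driven by
# `estimator.predict`. `input_fn` and the prediction keys are assumptions
# about the surrounding inference script.
def _demo_predict(estimator, input_fn):
  for prediction in estimator.predict(input_fn, yield_single_examples=True):
    tf.logging.info("prediction keys: %s", sorted(prediction.keys()))
    break  # Inspect only the first example.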