# Imports inferred from usage in this excerpt (an assumption; the original
# files' import blocks are not shown).
import json
import os

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tqdm import tqdm

import index_corpus  # Local module: concept-vocab loading and simple matching.

FLAGS = flags.FLAGS


def main(_):
  # Load all basic data.
  if FLAGS.do_init_concept2freq:
    do_init_concept2freq()
    return
  entity2id = index_corpus.load_concept_vocab(FLAGS.corpus_concept_vocab)
  if FLAGS.do_meantion_counting:  # (sic) Spelling matches the flag/function definitions.
    with tf.gfile.Open(FLAGS.concept_frequency_dict_path) as f:
      logging.info("Reading %s", f.name)
      concept2freq = json.load(f)
    do_meantion_counting(concept2freq)
  if FLAGS.do_generate_entity_networks:
    with tf.gfile.Open(
        os.path.join(FLAGS.index_data_dir, "mentions.npy"), "rb") as f:
      logging.info("Reading %s", f.name)
      mentions = np.load(f)
    e2m_ragged_ind, _ = load_entity2mention()
    do_generate_entity_networks(entity2id, e2m_ragged_ind, mentions)
  if FLAGS.do_qa_hop_analysis:
    do_qa_hop_analysis(entity2id)
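# The two main(_) functions in this excerpt come from separate scripts; each
# file would normally end with the standard absl entry point below. This
# stanza is an assumption based on the main(_) signature, not part of the
# original excerpt.
if __name__ == "__main__":
  app.run(main)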
# --- Second script: converts CSQA(-formatted) questions into a
# --- HotpotQA/DrKit-style JSON-lines dataset.
def main(_):
  global entity2id
  logging.set_verbosity(logging.INFO)
  logging.info("Reading CSQA(-formatted) data...")
  with tf.gfile.Open(FLAGS.csqa_file) as f:
    jsonlines = f.read().split("\n")
  data = [json.loads(jsonline) for jsonline in jsonlines if jsonline]
  logging.info("Done.")

  logging.info("Entity linking %d questions...", len(data))
  all_questions = []
  entity2id = index_corpus.load_concept_vocab(FLAGS.indexed_concept_file)
  linked_question_entities = []
  for item in tqdm(data, desc="Matching concepts in the questions."):
    concept_mentions = index_corpus.simple_match(
        lemmas=item["question"]["lemmas"],
        concept_vocab=entity2id,
        max_len=4,
        disable_overlap=True)  # Note: we want to limit the size of init facts.
    # Deduplicate mentions while preserving order.
    qry_concept_set = set()
    qry_concept_list = []
    for m in concept_mentions:
      c = m["mention"].lower()
      if c not in qry_concept_set:
        qry_concept_set.add(c)
        qry_concept_list.append({"kb_id": c, "name": c})
    linked_question_entities.append(qry_concept_list)

  num_empty_questions = 0
  num_empty_choices = 0
  num_empty_answers = 0
  for ii, item in tqdm(enumerate(data), desc="Processing", total=len(data)):
    # pylint: disable=g-complex-comprehension
    if item["answerKey"] in "ABCDE":
      truth_choice = ord(item["answerKey"]) - ord("A")  # E.g., "A" --> 0.
    elif item["answerKey"] in "12345":
      truth_choice = int(item["answerKey"]) - 1
    else:
      # Guard against an unassigned truth_choice (unhandled in the original).
      raise ValueError("Unexpected answerKey: %s" % item["answerKey"])
    choices = item["question"]["choices"]
    assert choices[truth_choice]["label"] == item["answerKey"]
    correct_answer = choices[truth_choice]["text"].lower()
    # Collect the mentioned concepts in each choice.
    choice2concepts = {}
    for c in choices:
      mentioned_concepts = []
      for m in index_corpus.simple_match(
          c["lemmas"],
          concept_vocab=entity2id,
          max_len=4,
          disable_overlap=FLAGS.disable_overlap):
        mentioned_concepts.append(m["mention"])
      choice2concepts[c["text"].lower()] = set(mentioned_concepts)
    choice2concepts = remove_intersection(choice2concepts)
    non_empty_choices = sum(bool(co) for co in choice2concepts.values())
    num_empty_choices += len(choices) - non_empty_choices

    if not linked_question_entities[ii]:
      # The question mentions no known concepts; skip it.
      num_empty_questions += 1
      continue
    if not choice2concepts[correct_answer]:
      # The correct answer does not contain any concepts; skip it.
      num_empty_answers += 1
      continue
    if FLAGS.do_filtering > 0 and non_empty_choices < FLAGS.do_filtering:
      continue
    choice2concepts = {
        k: sorted(v, key=lambda x: -len(x))  # Sort concepts by length, longest first.
        for k, v in choice2concepts.items()
    }
    sup_facts = choice2concepts[correct_answer]
    all_questions.append({
        "question": item["question"]["stem"],
        "entities": linked_question_entities[ii],
        "answer": correct_answer,
        "_id": item["id"],
        "level": "N/A",  # HotpotQA-specific key: hard/medium/easy.
        "type": "N/A",  # HotpotQA-specific key: comparison, bridge, etc.
        "supporting_facts": [{  # HotpotQA-specific key.
            "kb_id": c,
            "name": c
        } for c in sup_facts],
        "choice2concepts": choice2concepts,
    })

  with tf.gfile.Open(FLAGS.output_file, "w") as f_out:
    logging.info("Writing questions to output file...%s", f_out.name)
    logging.info("Number of questions %d", len(all_questions))
    f_out.write("\n".join(json.dumps(q) for q in all_questions))
  logging.info("===============================================")
  logging.info("%d questions without entities (out of %d)",
               num_empty_questions, len(data))
  logging.info("%d answers not IN entities (out of %d)",
               num_empty_answers, len(data))
  logging.info("%d choices not IN entities (out of %d)",
               num_empty_choices, 5 * len(data))
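# A minimal sketch of what remove_intersection (defined elsewhere in this
# file) presumably does, inferred from its name and usage above: drop the
# concepts shared by every choice, so each choice's concept set is
# discriminative. This is an assumption, not the source implementation.
def remove_intersection_sketch(choice2concepts):
  shared = set.intersection(*choice2concepts.values())
  return {choice: concepts - shared
          for choice, concepts in choice2concepts.items()}


# Shape of one JSON-lines record written to FLAGS.output_file, reconstructed
# from the dict built above; all values here are hypothetical:
#   {"question": "Where would you find a seahorse?",
#    "entities": [{"kb_id": "seahorse", "name": "seahorse"}],
#    "answer": "ocean",
#    "_id": "abc123",
#    "level": "N/A",
#    "type": "N/A",
#    "supporting_facts": [{"kb_id": "ocean", "name": "ocean"}],
#    "choice2concepts": {"ocean": ["ocean"], "aquarium": ["aquarium"]}}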