Exemple #1
0
class ApproximatePageRank(object):
    """APR main lib which is used to wrap functions around ppr algo."""
    def __init__(self):
        self.data = CsrData()
        self.data.load_csr_data(full_wiki=FLAGS.full_wiki,
                                files_dir=FLAGS.apr_files_dir)

    def get_topk_extracted_ent(self, seeds, alpha, topk):
        """Extract topk entities given seeds.

    Args:
      seeds: An Ex1 vector with weight on every seed entity
      alpha: probability for PPR
      topk: max top entities to extract
    Returns:
      extracted_ents: list of selected entities
      extracted_scores: list of scores of selected entities
    """
        ppr_scores = csr_personalized_pagerank(seeds, self.data.adj_mat_t_csr,
                                               alpha)
        sorted_idx = np.argsort(ppr_scores)[::-1]
        extracted_ents = sorted_idx[:topk]
        extracted_scores = ppr_scores[sorted_idx[:topk]]

        # Check for really low values
        # Get idx of First value < 1e-6, limit extracted ents till there
        zero_idx = np.where(ppr_scores[extracted_ents] < 1e-6)[0]
        if zero_idx.shape[0] > 0:
            extracted_ents = extracted_ents[:zero_idx[0]]

        return extracted_ents, extracted_scores

    def get_facts(self, entities, topk, alpha, seed_weighting=True):
        """Get subgraph describing a neighbourhood around given entities.

    Args:
      entities: A list of Wikidata entities
      topk: Max entities to extract from PPR
      alpha: Node probability for PPR
      seed_weighting: Boolean for performing weighting seeds by freq in passage

    Returns:
      unique_facts: A list of unique facts around the seeds.
    """

        if FLAGS.verbose_logging:
            tf.logging.info('Getting subgraph')
        entity_ids = [
            int(self.data.ent2id[x]) for x in entities if x in self.data.ent2id
        ]
        if FLAGS.verbose_logging:
            tf.logging.info(
                str([
                    self.data.entity_names['e'][str(x)]['name']
                    for x in entity_ids
                ]))
        freq_dict = {x: entity_ids.count(x) for x in entity_ids}

        seed = np.zeros((self.data.adj_mat.shape[0], 1))
        if not seed_weighting:
            seed[entity_ids] = 1. / len(set(entity_ids))
        else:
            for x, y in freq_dict.items():
                seed[x] = y
            seed = seed / seed.sum()

        extracted_ents, extracted_scores = self.get_topk_extracted_ent(
            seed, alpha, topk)
        if FLAGS.verbose_logging:
            tf.logging.info('Extracted ents: ')
            tf.logging.info(
                str([
                    self.data.entity_names['e'][str(x)]['name']
                    for x in extracted_ents
                ]))

        facts = csr_topk_fact_extractor(self.data.adj_mat_t_csr,
                                        self.data.rel_dict, freq_dict,
                                        self.data.entity_names, extracted_ents,
                                        extracted_scores)
        if FLAGS.verbose_logging:
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(facts))

        # Extract 1 unique fact per pair of entities (fact with highest score)
        # Sort by scores
        unique_facts = {}
        for (sub, obj, rel, score) in facts:
            fwd_dir = (sub, obj)
            rev_dir = (obj, sub)
            if fwd_dir in unique_facts and score > unique_facts[fwd_dir][1]:
                unique_facts[fwd_dir] = (rel, score)
            elif rev_dir in unique_facts and score > unique_facts[rev_dir][1]:
                unique_facts[fwd_dir] = (rel, score)
                del unique_facts[rev_dir]  # Remove existing entity pair
            else:
                unique_facts[(sub, obj)] = (rel, score)
        unique_facts = list(unique_facts.items())
        return unique_facts
class ShortestPath(object):
    """Shortest Path main lib which is used to wrap functions around Shortest Path algo."""
    def __init__(self,
                 mode=None,
                 task_id=None,
                 shard_id=None,
                 question_id=None):
        self.data = CsrData()
        self.data.load_csr_data(full_wiki=FLAGS.full_wiki,
                                files_dir=FLAGS.apr_files_dir,
                                mode=mode,
                                task_id=task_id,
                                shard_id=shard_id,
                                question_id=question_id)
        self.high_freq_relations = {
            'P31': 'instance of',
            'P17': 'country',
            'P131': 'located in the administrative territorial entity',
            'P106': 'occupation',
            'P21': 'sex or gender',
            'P735': 'given name',
            'P27': 'country of citizenship',
            'P19': 'place of birth'
        }

    def get_khop_entities(self, seeds, k_hop):
        print("id2ent size: %d", len(self.data.id2ent))
        entity_ids = [
            int(self.data.ent2id[x]) for x in seeds if x in self.data.ent2id
        ]
        khop_entity_ids = csr_get_k_hop_entities(entity_ids,
                                                 self.data.adj_mat_t_csr,
                                                 k_hop)
        khop_entities = [
            self.data.id2ent[str(x)] for x in khop_entity_ids
            if str(x) in self.data.id2ent.keys()
        ]
        return khop_entities

    def get_topk_extracted_ent(self, seeds, alpha, topk):
        """Extract topk entities given seeds.

        Args:
          seeds: An Ex1 vector with weight on every seed entity
          alpha: probability for PPR
          topk: max top entities to extract
        Returns:
          extracted_ents: list of selected entities
          extracted_scores: list of scores of selected entities
        """
        #tf.logging.info('Start ppr')
        ppr_scores = csr_personalized_pagerank(seeds, self.data.adj_mat_t_csr,
                                               alpha)
        #tf.logging.info('End ppr')
        sorted_idx = np.argsort(ppr_scores)[::-1]
        extracted_ents = sorted_idx[:topk]
        extracted_scores = ppr_scores[sorted_idx[:topk]]

        # Check for really low values
        # Get idx of First value < 1e-6, limit extracted ents till there
        zero_idx = np.where(ppr_scores[extracted_ents] < 1e-6)[0]
        if zero_idx.shape[0] > 0:
            extracted_ents = extracted_ents[:zero_idx[0]]

        return extracted_ents, extracted_scores

    def get_augmented_facts(self, path, entity_names, augmentation_type=None):
        augmented_path = []
        for single_path in path:
            single_path = single_path[1:]
            for (obj_id, rel_id, subj_id) in reversed(single_path):
                if obj_id == subj_id:
                    continue
                subj_name = entity_names['e'][str(subj_id)]['name']
                obj_name = entity_names['e'][str(
                    obj_id)]['name'] if str(obj_id) != 'None' else 'None'
                rel_name = entity_names['r'][str(
                    rel_id)]['name'] if str(rel_id) != 'None' else 'None'
                augmented_path.append(
                    (((subj_id, subj_name), (obj_id, obj_name)),
                     ((rel_id, rel_name), None)))
        return augmented_path

    def get_all_path_augmented_facts(self,
                                     path,
                                     entity_names,
                                     augmentation_type=None):
        augmented_path = []
        for single_path in path:
            tmp_path = []
            single_path = single_path[1:]
            for (obj_id, rel_id, subj_id) in reversed(single_path):
                if obj_id == subj_id:
                    continue
                subj_name = entity_names['e'][str(subj_id)]['name']
                obj_name = entity_names['e'][str(
                    obj_id)]['name'] if str(obj_id) != 'None' else 'None'
                rel_name = entity_names['r'][str(
                    rel_id)]['name'] if str(rel_id) != 'None' else 'None'
                tmp_path.append((((subj_id, subj_name), (obj_id, obj_name)),
                                 ((rel_id, rel_name), None)))
            augmented_path.append(tmp_path)
        return augmented_path

    def get_shortest_path_facts(self,
                                question_entities,
                                answer_entities,
                                passage_entities,
                                seed_weighting=True,
                                fp=None):
        """Get subgraph describing shortest path from question to answer.

        Args:
          question_entities: A list of Wikidata entities
          answer_entities: A list of Wikidata entities
          passage_entities: A list of Wikidata entities

        Returns:
          unique_facts: A list of unique facts representing the shortest path.
        """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_shortest_path(
            question_entity_ids,
            self.data.adj_mat_t_csr,
            answer_entity_ids,
            self.data.rel_dict,
            k_hop=FLAGS.k_hop)
        augmented_facts = self.get_augmented_facts(extracted_paths,
                                                   self.data.entity_names)

        if FLAGS.verbose_logging:
            print('Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops

    def get_question_to_passage_facts(self,
                                      question_entities,
                                      answer_entities,
                                      passage_entities,
                                      seed_weighting=True,
                                      fp=None):
        """Get subgraph describing shortest path from question to answer.

        Args:
          question_entities: A list of Wikidata entities
          answer_entities: A list of Wikidata entities
          passage_entities: A list of Wikidata entities

        Returns:
          unique_facts: A list of unique facts representing the shortest path.
        """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_all_paths(question_entity_ids,
                                                      self.data.adj_mat_t_csr,
                                                      passage_entity_ids,
                                                      self.data.rel_dict,
                                                      k_hop=FLAGS.k_hop)
        augmented_facts = self.get_all_path_augmented_facts(
            extracted_paths, self.data.entity_names)

        if FLAGS.verbose_logging:
            print('All path Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('All path Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops

    def get_all_path_facts(self,
                           question_entities,
                           answer_entities,
                           passage_entities,
                           seed_weighting=True,
                           fp=None):
        """Get subgraph describing shortest path from question to answer.

        Args:
          question_entities: A list of Wikidata entities
          answer_entities: A list of Wikidata entities
          passage_entities: A list of Wikidata entities

        Returns:
          unique_facts: A list of unique facts representing the shortest path.
        """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_all_paths(question_entity_ids,
                                                      self.data.adj_mat_t_csr,
                                                      answer_entity_ids,
                                                      self.data.rel_dict,
                                                      k_hop=FLAGS.k_hop)
        augmented_facts = self.get_all_path_augmented_facts(
            extracted_paths, self.data.entity_names)

        if FLAGS.verbose_logging:
            print('All path Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('All path Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops
class ApproximatePageRank(object):
    """APR main lib which is used to wrap functions around ppr algo."""
    def __init__(self,
                 mode=None,
                 task_id=None,
                 shard_id=None,
                 question_id=None,
                 apr_path=None):
        self.data = CsrData()
        apr_path = FLAGS.apr_files_dir if apr_path is None else apr_path
        self.data.load_csr_data(full_wiki=FLAGS.full_wiki,
                                files_dir=apr_path,
                                mode=mode,
                                task_id=task_id,
                                shard_id=shard_id,
                                question_id=question_id)
        self.high_freq_relations = {
            'P31': 'instance of',
            'P17': 'country',
            'P131': 'located in the administrative territorial entity',
            'P106': 'occupation',
            'P21': 'sex or gender',
            'P735': 'given name',
            'P27': 'country of citizenship',
            'P19': 'place of birth'
        }
        # self.relations_to_filter = json.load('relations_to_filter.json')

    def get_khop_entities(self, seeds, k_hop):
        print("id2ent size: %d", len(self.data.id2ent))
        entity_ids = [
            int(self.data.ent2id[x]) for x in seeds if x in self.data.ent2id
        ]
        khop_entity_ids = csr_get_k_hop_entities(entity_ids,
                                                 self.data.adj_mat_t_csr,
                                                 k_hop)
        khop_entities = [
            self.data.id2ent[str(x)] for x in khop_entity_ids
            if str(x) in self.data.id2ent.keys()
        ]
        return khop_entities

    def get_khop_facts(self, seeds, k_hop):
        seeds = list(set(seeds))
        #print("id2ent size: %d", len(self.data.id2ent))
        entity_ids = [
            int(self.data.ent2id[x]) for x in seeds if x in self.data.ent2id
        ]
        khop_entity_ids, khop_facts = csr_get_k_hop_facts(
            entity_ids, self.data.adj_mat_t_csr, self.data.rel_dict, k_hop)
        khop_entities = [
            self.data.id2ent[str(x)] for x in khop_entity_ids
            if str(x) in self.data.id2ent.keys()
        ]
        khop_facts = [((self.data.id2ent[str(s)],
                        self.data.entity_names['e'][str(s)]['name']),
                       (self.data.id2rel[r],
                        self.data.entity_names['r'][str(r)]['name']),
                       (self.data.id2ent[str(o)],
                        self.data.entity_names['e'][str(o)]['name']))
                      for (s, r, o) in khop_facts]
        return khop_entities, khop_facts

    def get_topk_extracted_ent(self, seeds, alpha, topk):
        """Extract topk entities given seeds.

    Args:
      seeds: An Ex1 vector with weight on every seed entity
      alpha: probability for PPR
      topk: max top entities to extract
    Returns:
      extracted_ents: list of selected entities
      extracted_scores: list of scores of selected entities
    """
        #tf.logging.info('Start ppr')
        ppr_scores = csr_personalized_pagerank(seeds, self.data.adj_mat_t_csr,
                                               alpha, self.data.entity_names)
        #tf.logging.info('End ppr')
        sorted_idx = np.argsort(ppr_scores)[::-1]
        extracted_ents = sorted_idx[:topk]
        extracted_scores = ppr_scores[sorted_idx[:topk]]

        # Check for really low values
        # Get idx of First value < 1e-6, limit extracted ents till there
        zero_idx = np.where(ppr_scores[extracted_ents] < 1e-6)[0]
        if zero_idx.shape[0] > 0:
            extracted_ents = extracted_ents[:zero_idx[0]]

        return extracted_ents, extracted_scores

    def get_facts(self, entities, topk, alpha, seed_weighting=True):
        """Get subgraph describing a neighbourhood around given entities.

    Args:
      entities: A list of Wikidata entities
      topk: Max entities to extract from PPR
      alpha: Node probability for PPR
      seed_weighting: Boolean for performing weighting seeds by freq in passage

    Returns:
      unique_facts: A list of unique facts around the seeds.
    """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        entity_ids = [
            int(self.data.ent2id[x]) for x in entities if x in self.data.ent2id
        ]
        if FLAGS.verbose_logging:
            print(
                str([
                    self.data.entity_names['e'][str(x)]['name']
                    for x in entity_ids
                ]))
            tf.logging.info(
                str([
                    self.data.entity_names['e'][str(x)]['name']
                    for x in entity_ids
                ]))
        freq_dict = {x: entity_ids.count(x) for x in entity_ids}

        seed = np.zeros((self.data.adj_mat_t_csr.shape[0], 1))
        if not seed_weighting:
            seed[entity_ids] = 1. / len(set(entity_ids))
        else:
            for x, y in freq_dict.items():
                seed[x] = y
            seed = seed / seed.sum()

        extracted_ents, extracted_scores = self.get_topk_extracted_ent(
            seed, alpha, topk)
        if FLAGS.verbose_logging:
            print('Extracted Ents')
            tf.logging.info('Extracted ents: ')
            tf.logging.info(
                str([
                    self.data.entity_names['e'][str(x)]['name']
                    for x in extracted_ents
                ]))
            print(
                str([(self.data.entity_names['e'][str(x)]['name'],
                      extracted_scores[idx])
                     for idx, x in enumerate(extracted_ents)][0:15]))

        facts = csr_topk_fact_extractor(self.data.adj_mat_t_csr,
                                        self.data.rel_dict, freq_dict,
                                        self.data.entity_names, extracted_ents,
                                        extracted_scores)
        if FLAGS.verbose_logging:
            #print('Extracted facts: ')
            #print(str(facts))
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(facts))

        # Extract 1 unique fact per pair of entities (fact with highest score)
        # Sort by scores
        unique_facts = {}
        for (sub, obj, rel, score) in facts:
            fwd_dir = (sub, obj, rel)
            rev_dir = (obj, sub, rel)
            if sub[1] == obj[1]:  #No self-links
                continue
            #fwd_dir = (sub[1], obj[1])
            #rev_dir = (obj[1], sub[1])

            if fwd_dir in unique_facts:
                if score > unique_facts[fwd_dir][1]:
                    unique_facts[fwd_dir] = (rel, score)
                else:
                    continue
            elif rev_dir in unique_facts:
                if score > unique_facts[rev_dir][1]:
                    unique_facts[fwd_dir] = (rel, score)
                    del unique_facts[rev_dir]  # Remove existing entity pair
                else:
                    continue
            else:
                unique_facts[fwd_dir] = (rel, score)
        unique_facts = list(unique_facts.items())
        return unique_facts

    def get_random_facts(self, entities, topk, alpha, seed_weighting=True):
        """Get random subgraph

      Args:
        entities: A list of Wikidata entities
        topk: Max entities to extract from PPR
        alpha: Node probability for PPR
        seed_weighting: Boolean for performing weighting seeds by freq in passage

      Returns:
        unique_facts: A list of unique random facts around the seeds.
      """
        #ent_ids = list(self.data.entity_names['e'].keys())
        ent_ids = [i for i in range(self.data.adj_mat_t_csr.shape[0])]
        extracted_ents = random.sample(ent_ids, 500)  # This doesn't work :(
        freq_dict = {}
        for i in extracted_ents:
            freq_dict[i] = 1
        extracted_scores = [1] * len(extracted_ents)
        facts = csr_topk_fact_extractor(self.data.adj_mat_t_csr,
                                        self.data.rel_dict, freq_dict,
                                        self.data.entity_names, extracted_ents,
                                        extracted_scores)
        if FLAGS.verbose_logging:
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(facts))

            # Extract 1 unique fact per pair of entities (fact with highest score)
            # Sort by scores
        unique_facts = {}
        for (sub, obj, rel, score) in facts:
            fwd_dir = (sub, obj)
            rev_dir = (obj, sub)
            if fwd_dir in unique_facts and score > unique_facts[fwd_dir][1]:
                unique_facts[fwd_dir] = (rel, score)
            elif rev_dir in unique_facts and score > unique_facts[rev_dir][1]:
                unique_facts[fwd_dir] = (rel, score)
                del unique_facts[rev_dir]  # Remove existing entity pair
            else:
                unique_facts[(sub, obj)] = (rel, score)
        unique_facts = list(unique_facts.items())
        return unique_facts

    def get_random_facts_of_question(self,
                                     question_entities,
                                     answer_entities,
                                     passage_entities,
                                     seed_weighting=True,
                                     fp=None):
        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        random_facts = csr_get_random_facts_of_question(
            question_entity_ids, self.data.adj_mat_t_csr, answer_entities,
            self.data.rel_dict)
        if FLAGS.num_facts_limit > 0:
            random_facts = random_facts[0:FLAGS.num_facts_limit]
        augmented_facts = self.get_augmented_facts(random_facts,
                                                   self.data.entity_names)
        return augmented_facts

    def get_augmented_facts(self, path, entity_names, augmentation_type=None):
        augmented_path = []
        for single_path in path:
            single_path = single_path[1:]
            for (obj_id, rel_id, subj_id) in reversed(single_path):
                if obj_id == subj_id:
                    continue
                subj_name = entity_names['e'][str(subj_id)]['name']
                obj_name = entity_names['e'][str(
                    obj_id)]['name'] if str(obj_id) != 'None' else 'None'
                rel_name = entity_names['r'][str(
                    rel_id)]['name'] if str(rel_id) != 'None' else 'None'
                augmented_path.append(
                    (((subj_id, subj_name), (obj_id, obj_name)),
                     ((rel_id, rel_name), None)))
        return augmented_path

    def get_all_path_augmented_facts(self,
                                     path,
                                     entity_names,
                                     augmentation_type=None):
        augmented_path = []
        for single_path in path:
            tmp_path = []
            single_path = single_path[1:]
            for (obj_id, rel_id, subj_id) in reversed(single_path):
                if obj_id == subj_id:
                    continue
                subj_name = entity_names['e'][str(subj_id)]['name']
                obj_name = entity_names['e'][str(
                    obj_id)]['name'] if str(obj_id) != 'None' else 'None'
                rel_name = entity_names['r'][str(
                    rel_id)]['name'] if str(rel_id) != 'None' else 'None'
                tmp_path.append((((subj_id, subj_name), (obj_id, obj_name)),
                                 ((rel_id, rel_name), None)))
            augmented_path.append(tmp_path)
        return augmented_path

    def get_shortest_path_facts(self,
                                question_entities,
                                answer_entities,
                                passage_entities,
                                seed_weighting=True,
                                fp=None,
                                seperate_diff_paths=False,
                                filter_relations=False):
        """Get subgraph describing shortest path from question to answer.

      Args:
        question_entities: A list of Wikidata entities
        answer_entities: A list of Wikidata entities
        passage_entities: A list of Wikidata entities

      Returns:
        unique_facts: A list of unique facts representing the shortest path.
      """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_shortest_path(
            question_entity_ids,
            self.data.adj_mat_t_csr,
            answer_entity_ids,
            self.data.rel_dict,
            k_hop=FLAGS.k_hop,
            filter_relations=filter_relations,
            relations_to_filter=self.data.relations_to_filter,
            id2rel=self.data.id2rel)

        if seperate_diff_paths:
            augmented_facts = self.get_all_path_augmented_facts(
                extracted_paths, self.data.entity_names)
        else:
            augmented_facts = self.get_augmented_facts(extracted_paths,
                                                       self.data.entity_names)

        if FLAGS.verbose_logging:
            print('Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops

    def get_question_links(self,
                           question_entities,
                           answer_entities,
                           passage_entities,
                           seed_weighting=True,
                           fp=None,
                           seperate_diff_paths=False,
                           filter_relations=False):
        """Get subgraph describing shortest path from question to answer.

      Args:
        question_entities: A list of Wikidata entities
        answer_entities: A list of Wikidata entities
        passage_entities: A list of Wikidata entities

      Returns:
        unique_facts: A list of unique facts representing the shortest path.
      """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_facts, relations = csr_get_question_links(
            question_entity_ids,
            self.data.adj_mat_t_csr,
            answer_entity_ids,
            self.data.rel_dict,
            filter_relations=filter_relations,
            relations_to_filter=self.data.relations_to_filter,
            id2rel=self.data.id2rel)
        augmented_facts = []
        relations = []
        for (subj_id, rel_id, obj_id) in extracted_facts:
            if obj_id == subj_id:
                continue
            subj_name = self.data.entity_names['e'][str(subj_id)]['name']
            obj_name = self.data.entity_names['e'][str(
                obj_id)]['name'] if str(obj_id) != 'None' else 'None'
            rel_name = self.data.entity_names['r'][str(
                rel_id)]['name'] if str(rel_id) != 'None' else 'None'
            augmented_facts.append((((subj_id, subj_name), (obj_id, obj_name)),
                                    ((rel_id, rel_name), None)))
            rel_kb_id = self.data.id2rel[rel_id]
            relations.append((rel_id, rel_name))
        if FLAGS.verbose_logging:
            print('Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('Extracted facts: ')
            tf.logging.info(str(augmented_facts))
        return augmented_facts, relations

    def get_question_to_passage_facts(self,
                                      question_entities,
                                      answer_entities,
                                      passage_entities,
                                      seed_weighting=True,
                                      fp=None):
        """Get subgraph describing shortest path from question to answer.

      Args:
        question_entities: A list of Wikidata entities
        answer_entities: A list of Wikidata entities
        passage_entities: A list of Wikidata entities

      Returns:
        unique_facts: A list of unique facts representing the shortest path.
      """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_all_paths(question_entity_ids,
                                                      self.data.adj_mat_t_csr,
                                                      passage_entity_ids,
                                                      self.data.rel_dict,
                                                      k_hop=FLAGS.k_hop)
        augmented_facts = self.get_all_path_augmented_facts(
            extracted_paths, self.data.entity_names)

        if FLAGS.verbose_logging:
            print('All path Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('All path Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops

    def get_all_path_facts(self,
                           question_entities,
                           answer_entities,
                           passage_entities,
                           seed_weighting=True,
                           fp=None):
        """Get subgraph describing shortest path from question to answer.

      Args:
        question_entities: A list of Wikidata entities
        answer_entities: A list of Wikidata entities
        passage_entities: A list of Wikidata entities

      Returns:
        unique_facts: A list of unique facts representing the shortest path.
      """

        if FLAGS.verbose_logging:
            print('Getting subgraph')
            tf.logging.info('Getting subgraph')
        question_entity_ids = [
            int(self.data.ent2id[x]) for x in question_entities
            if x in self.data.ent2id
        ]
        question_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in question_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(question_entities)+"\t"+question_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Question Entities')
            tf.logging.info('Question Entities')
            print(question_entities)
            print(question_entity_names)
            tf.logging.info(question_entity_names)

        answer_entity_ids = [
            int(self.data.ent2id[x]) for x in answer_entities
            if x in self.data.ent2id
        ]
        answer_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in answer_entity_ids
        ])
        #if fp is not None:
        #    fp.write(str(answer_entities)+"\t"+answer_entity_names+"\t")
        if FLAGS.verbose_logging:
            print('Answer Entities')
            tf.logging.info('Answer Entities')
            print(answer_entities)
            print(answer_entity_names)
            tf.logging.info(answer_entity_names)
        passage_entity_ids = [
            int(self.data.ent2id[x]) for x in passage_entities
            if x in self.data.ent2id
        ]
        passage_entity_names = str([
            self.data.entity_names['e'][str(x)]['name']
            for x in passage_entity_ids
        ])
        if FLAGS.verbose_logging:
            print('Passage Entities')
            tf.logging.info('Passage Entities')
            print(passage_entity_names)
            tf.logging.info(passage_entity_names)

        freq_dict = {
            x: question_entity_ids.count(x)
            for x in question_entity_ids
        }

        extracted_paths, num_hops = csr_get_all_paths(question_entity_ids,
                                                      self.data.adj_mat_t_csr,
                                                      answer_entity_ids,
                                                      self.data.rel_dict,
                                                      k_hop=FLAGS.k_hop)
        augmented_facts = self.get_all_path_augmented_facts(
            extracted_paths, self.data.entity_names)

        if FLAGS.verbose_logging:
            print('All path Extracted facts: ')
            print(str(augmented_facts))
            tf.logging.info('All path Extracted facts: ')
            tf.logging.info(str(augmented_facts))
            print("Num hops: " + str(num_hops))
        return augmented_facts, num_hops