예제 #1
0
 def setUpClass(cls):
     """Create the shared fixtures used by the triple tests.

     ``cls.index`` / ``cls.literals`` hold the (head, relation, tail) values as
     plain tuples; ``cls.triple_index`` / ``cls.triple`` are default-constructed
     kgedata objects whose fields are then set from those tuples.
     """
     fields = ("head", "relation", "tail")

     cls.index = (1, 2, 3)
     cls.triple_index = kgedata.TripleIndex()
     for field, value in zip(fields, cls.index):
         setattr(cls.triple_index, field, value)

     cls.literals = ("/m/1", "/m/2", "/m/3")
     cls.triple = kgedata.Triple()
     for field, value in zip(fields, cls.literals):
         setattr(cls.triple, field, value)
예제 #2
0
    def __call__(self, t):
        """Convert a literal triple into a ``kgedata.TripleIndex``.

        Entity names (head, then tail) and the relation name that have not
        been seen before get the next free id from ``self._ent_id`` /
        ``self._rel_id``, which are advanced accordingly.
        """
        for entity in (t.head, t.tail):
            if entity in self._ents:
                continue
            self._ents[entity] = self._ent_id
            self._ent_id += 1

        if t.relation not in self._rels:
            self._rels[t.relation] = self._rel_id
            self._rel_id += 1

        head_id = self._ents[t.head]
        relation_id = self._rels[t.relation]
        tail_id = self._ents[t.tail]
        return kgedata.TripleIndex(head_id, relation_id, tail_id)
예제 #3
0
def sieve_and_expand_triple(triple_source, entities, relations, head, relation,
                            tail):
    """Tile a query triple over every candidate for its unknown element.

    One of ``head``, ``relation`` or ``tail`` is expected to be the placeholder
    ``'?'`` (checked in that order); the other two are names looked up in
    ``entities`` / ``relations``. The unknown slot is expanded to every
    candidate id (``np.arange`` over ``triple_source.num_entity`` or
    ``triple_source.num_relation``) and each known id is repeated to the same
    length.

    Args:
        triple_source: object exposing ``num_entity`` and ``num_relation``.
        entities: mapping from entity name to integer id.
        relations: mapping from relation name to integer id.
        head: head entity name, or ``'?'``.
        relation: relation name, or ``'?'``.
        tail: tail entity name, or ``'?'``.

    Returns:
        ``((h, r, t), prediction_type, triple_index)`` where ``h``, ``r`` and
        ``t`` are equal-length int64 arrays, ``prediction_type`` is the
        matching ``constants.HEAD_KEY`` / ``RELATION_KEY`` / ``TAIL_KEY`` and
        ``triple_index`` is a ``kgedata.TripleIndex`` with ``-1`` in the
        unknown slot.

    Raises:
        RuntimeError: if none of ``head``, ``relation``, ``tail`` is ``'?'``.
    """
    deprecation("Not tested anymore", since="0.3.0")

    def _repeat(value, count):
        # A single int64 id repeated ``count`` times.
        return np.tile(np.array([value], dtype=np.int64), count)

    if head == '?':
        r = relations[relation]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(-1, r, t)

        num_candidates = triple_source.num_entity
        h = np.arange(num_candidates, dtype=np.int64)
        r = _repeat(r, num_candidates)
        t = _repeat(t, num_candidates)
        prediction_type = constants.HEAD_KEY
    elif relation == '?':
        h = entities[head]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(h, -1, t)

        num_candidates = triple_source.num_relation
        h = _repeat(h, num_candidates)
        r = np.arange(num_candidates, dtype=np.int64)
        t = _repeat(t, num_candidates)
        prediction_type = constants.RELATION_KEY
    elif tail == '?':
        r = relations[relation]
        h = entities[head]
        triple_index = kgedata.TripleIndex(h, r, -1)

        num_candidates = triple_source.num_entity
        h = _repeat(h, num_candidates)
        r = _repeat(r, num_candidates)
        t = np.arange(num_candidates, dtype=np.int64)
        prediction_type = constants.TAIL_KEY
    else:
        raise RuntimeError("head, relation, tail are known.")

    return (h, r, t), prediction_type, triple_index
 def test_attrs(self):
     """The indexer exposes the expected id maps, index list and name lists."""
     expected_entity_ids = {
         '/m/entity1': 0,
         '/m/entity2': 1,
         '/m/entity3': 2,
         '/m/entity4': 3,
     }
     self.assertEqual(self._indexer.entityIdMap(), expected_entity_ids)

     expected_relation_ids = {
         '/produced_by': 0,
         '/country': 1,
         '/produced_in': 2,
     }
     self.assertEqual(self._indexer.relationIdMap(), expected_relation_ids)

     expected_indexes = [
         kgedata.TripleIndex(h, r, t)
         for h, r, t in ((0, 0, 1), (1, 1, 2), (0, 2, 2), (2, 2, 3), (0, 1, 3))
     ]
     self.assertEqual(self._indexer.indexes(), expected_indexes)

     expected_entities = ['/m/entity1', '/m/entity2', '/m/entity3', '/m/entity4']
     self.assertEqual(self._indexer.entities(), expected_entities)

     expected_relations = ['/produced_by', '/country', '/produced_in']
     self.assertEqual(self._indexer.relations(), expected_relations)
예제 #5
0
    def shrink_indexes_in_place(self, triples):
        """Compact entity and relation ids down to those used by ``triples``.

        Ids that never occur in ``triples`` are dropped and the surviving ids
        are renumbered densely from 0, keeping their original relative order.
        ``self._ents`` / ``self._rels`` are replaced with the compacted
        name->id bidicts and ``self._ent_id`` / ``self._rel_id`` with the new
        counts.

        Args:
            triples: iterable of index triples with integer ``head``,
                ``relation`` and ``tail`` ids under the current mapping.

        Returns:
            A new list of ``kgedata.TripleIndex`` expressed in compacted ids.
        """
        # ``triples`` is traversed twice below; materialize it so a one-shot
        # iterator does not silently produce an empty result on the 2nd pass.
        triples = list(triples)

        ent_flags = [False] * self._ent_id
        rel_flags = [False] * self._rel_id
        for t in triples:
            ent_flags[t.head] = True
            ent_flags[t.tail] = True
            rel_flags[t.relation] = True

        # Build the compacted name->id maps plus an old-id -> new-id remap so
        # the triples can be rewritten without a second round of lookups.
        ents = bidict()
        ent_remap = {}
        ent_idx = 0
        for previous_idx, ent_exist in enumerate(ent_flags):
            if ent_exist:
                ents[self._ents.inverse[previous_idx]] = ent_idx
                ent_remap[previous_idx] = ent_idx
                ent_idx += 1
        logging.info("before shrinking: %s\nafter shrinking: %s",
                     self._ent_id, ent_idx)

        rels = bidict()
        rel_remap = {}
        rel_idx = 0
        for previous_idx, rel_exist in enumerate(rel_flags):
            if rel_exist:
                rels[self._rels.inverse[previous_idx]] = rel_idx
                rel_remap[previous_idx] = rel_idx
                rel_idx += 1
        logging.info("before shrinking: %s\nafter shrinking: %s",
                     self._rel_id, rel_idx)

        new_triples = [
            kgedata.TripleIndex(ent_remap[t.head], rel_remap[t.relation],
                                ent_remap[t.tail])
            for t in triples
        ]

        self._ent_id = ent_idx
        self._ents = ents

        self._rel_id = rel_idx
        self._rels = rels

        return new_triples
예제 #6
0
def build_index_and_mapping(triples):
    """Index all triples and return the indexes with their name->id mappings.

    Entity and relation names receive consecutive ids starting at 0 in the
    order they are first seen (for each triple: head, tail, then relation).

    Returns:
        ``(collected, ents, rels)``: the list of ``kgedata.TripleIndex`` in
        input order, and the entity / relation ``bidict`` mappings.
    """
    ents = bidict()
    rels = bidict()

    def _id_of(table, name):
        # Ids are handed out densely, so the next free id is the table size.
        if name not in table:
            table[name] = len(table)
        return table[name]

    collected = []
    for triple in triples:
        head_id = _id_of(ents, triple.head)
        tail_id = _id_of(ents, triple.tail)
        relation_id = _id_of(rels, triple.relation)
        collected.append(kgedata.TripleIndex(head_id, relation_id, tail_id))

    return collected, ents, rels
예제 #7
0
def evaulate_prediction_np_collate(model, triple_source, config, ranker,
                                   data_loader):
  """Rank predictions for every triple in ``data_loader`` (use with NumpyCollate).

  Deprecated since 0.5.0. Puts ``model`` in eval mode, then for each triple:
  head and tail predictions are evaluated when ``config.report_dimension``
  includes ``SEPERATE_ENTITY`` or ``COMBINED_ENTITY``, and relation
  predictions when it includes ``RELATION``. ``_evaluate_predict_element``
  receives the raw/filtered rank lists below and presumably appends to them
  (confirm against that helper).

  Returns:
    ``((head_ranks, filtered_head_ranks), (tail_ranks, filtered_tail_ranks),
    (relation_ranks, filtered_relation_ranks))``.
  """
  utils.deprecation("multiprocess validation is not in use any more", "0.5.0")
  model.eval()

  head_ranks = []
  filtered_head_ranks = []
  tail_ranks = []
  filtered_tail_ranks = []
  relation_ranks = []
  filtered_relation_ranks = []

  for sample_batched in data_loader:
    # sample_batched is a list of triples, each of shape (1, 3); the
    # per-candidate tiling for the test is left to _evaluate_predict_element.
    for triple in sample_batched:
      triple_index = kgedata.TripleIndex(*triple[0, :])

      if (config.report_dimension & stats.StatisticsDimension.SEPERATE_ENTITY
         ) or (config.report_dimension &
               stats.StatisticsDimension.COMBINED_ENTITY):
        _evaluate_predict_element(model, config, triple_index,
                                  triple_source.num_entity,
                                  data.TripleElement.HEAD, ranker.rank_head,
                                  head_ranks, filtered_head_ranks)
        _evaluate_predict_element(model, config, triple_index,
                                  triple_source.num_entity,
                                  data.TripleElement.TAIL, ranker.rank_tail,
                                  tail_ranks, filtered_tail_ranks)
      if config.report_dimension & stats.StatisticsDimension.RELATION:
        _evaluate_predict_element(
            model, config, triple_index, triple_source.num_relation,
            data.TripleElement.RELATION, ranker.rank_relation, relation_ranks,
            filtered_relation_ranks)

  return ((head_ranks, filtered_head_ranks),
          (tail_ranks, filtered_tail_ranks),
          (relation_ranks, filtered_relation_ranks))
예제 #8
0
def test_unwrap():
    """unpack() returns the (head, relation, tail) fields as a plain tuple."""
    index_triple = kgedata.TripleIndex(1, 2, 3)
    assert kgekit.data.unpack(index_triple) == (1, 2, 3)

    literal_triple = kgedata.Triple("/m/1", "/m/2", "/m/3")
    assert kgekit.data.unpack(literal_triple) == ("/m/1", "/m/2", "/m/3")
예제 #9
0
def test_pack_triples_numpy():
    """pack_triples_numpy yields one int64 row per TripleIndex."""
    triples = [kgedata.TripleIndex(1, 2, 3), kgedata.TripleIndex(4, 5, 6)]
    expected = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
    packed = kgekit.data.pack_triples_numpy(triples)
    np.testing.assert_equal(packed, expected)
예제 #10
0
def test_transform_triple_numpy():
    """_transform_triple_numpy turns a TripleIndex into a 1-D int64 array."""
    expected = np.array([1, 2, 3], dtype=np.int64)
    transformed = kgekit.data._transform_triple_numpy(kgedata.TripleIndex(1, 2, 3))
    np.testing.assert_equal(transformed, expected)
예제 #11
0
 def test_get_triple_index(self):
     """get_triple_index reads the columns in the order named by the spec."""
     cases = (
         ("hrt", kgedata.TripleIndex(1, 2, 3)),
         ("htr", kgedata.TripleIndex(1, 3, 2)),
     )
     for order, expected in cases:
         parsed = kgedata.get_triple_index("1 2 3", order, ' ')
         self.assertEqual(parsed, expected)
예제 #12
0
 def test_triple_index(self):
     """Positional construction matches the field-by-field fixture."""
     built = kgedata.TripleIndex(*self.index)
     self.assertEqual(built, self.triple_index)