def setUpClass(cls):
    """Create the fixtures shared by the tests of this case.

    Stores on the class an index tuple with a matching ``TripleIndex`` and a
    literal tuple with a matching ``Triple``; the objects are built empty and
    then populated field by field.
    """
    head_id, relation_id, tail_id = 1, 2, 3
    cls.index = (head_id, relation_id, tail_id)

    indexed = kgedata.TripleIndex()
    cls.triple_index = indexed
    indexed.head = head_id
    indexed.relation = relation_id
    indexed.tail = tail_id

    cls.literals = ("/m/1", "/m/2", "/m/3")

    literal = kgedata.Triple()
    cls.triple = literal
    literal.head = "/m/1"
    literal.relation = "/m/2"
    literal.tail = "/m/3"
def __call__(self, t):
    """Translate one literal triple into a ``TripleIndex``.

    Entity and relation names that have not been seen before are registered
    with the next free id (head is handled before tail), so ids are handed
    out in first-appearance order.

    Args:
        t: object exposing ``head``, ``relation`` and ``tail`` names.

    Returns:
        A ``kgedata.TripleIndex`` built from the ids of the three names.
    """
    for name in (t.head, t.tail):
        if name in self._ents:
            continue
        self._ents[name] = self._ent_id
        self._ent_id += 1

    if t.relation not in self._rels:
        self._rels[t.relation] = self._rel_id
        self._rel_id += 1

    head_id = self._ents[t.head]
    relation_id = self._rels[t.relation]
    tail_id = self._ents[t.tail]
    return kgedata.TripleIndex(head_id, relation_id, tail_id)
def sieve_and_expand_triple(triple_source, entities, relations, head, relation, tail):
    """Expand a triple query with one unknown slot over every candidate id.

    The slots are tested for the placeholder ``'?'`` in the order head,
    relation, tail; the first one that matches is treated as the unknown.
    The unknown slot becomes ``0..N-1`` and the two known slots are repeated
    ``N`` times, where ``N`` is ``triple_source.num_entity`` for an unknown
    head/tail and ``triple_source.num_relation`` for an unknown relation.

    Args:
        triple_source: object exposing ``num_entity`` and ``num_relation``
            (assumed to be plain ints -- TODO confirm).
        entities: mapping from entity name to entity id.
        relations: mapping from relation name to relation id.
        head: head entity name, or ``'?'`` when it is the unknown.
        relation: relation name, or ``'?'`` when it is the unknown.
        tail: tail entity name, or ``'?'`` when it is the unknown.

    Returns:
        A tuple ``((h, r, t), prediction_type, triple_index)`` where ``h``,
        ``r`` and ``t`` are ``int64`` arrays, ``prediction_type`` is one of
        ``constants.HEAD_KEY`` / ``RELATION_KEY`` / ``TAIL_KEY`` and
        ``triple_index`` is a ``kgedata.TripleIndex`` holding the known ids
        with ``-1`` in the unknown slot.

    Raises:
        RuntimeError: if none of the three slots is ``'?'``.
    """
    deprecation("Not tested anymore", since="0.3.0")

    # The previous body unpacked ``batch.shape`` and split ``batch`` here.
    # ``batch`` is not a parameter or local of this function and the results
    # were never used, so those statements could only fail; they are removed.

    def _repeat(value, count):
        # Same construction the function always used for a constant column.
        return np.tile(np.array([value], dtype=np.int64), count)

    if head == '?':
        r = relations[relation]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(-1, r, t)
        count = triple_source.num_entity
        h = np.arange(count, dtype=np.int64)
        r = _repeat(r, count)
        t = _repeat(t, count)
        prediction_type = constants.HEAD_KEY
    elif relation == '?':
        h = entities[head]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(h, -1, t)
        count = triple_source.num_relation
        h = _repeat(h, count)
        r = np.arange(count, dtype=np.int64)
        t = _repeat(t, count)
        prediction_type = constants.RELATION_KEY
    elif tail == '?':
        r = relations[relation]
        h = entities[head]
        triple_index = kgedata.TripleIndex(h, r, -1)
        count = triple_source.num_entity
        h = _repeat(h, count)
        r = _repeat(r, count)
        t = np.arange(count, dtype=np.int64)
        prediction_type = constants.TAIL_KEY
    else:
        raise RuntimeError("head, relation, tail are known.")

    return (h, r, t), prediction_type, triple_index
def test_attrs(self):
    """Check the indexer's id maps, indexed triples and name lists."""
    indexer = self._indexer

    expected_entity_ids = {
        '/m/entity1': 0,
        '/m/entity2': 1,
        '/m/entity3': 2,
        '/m/entity4': 3,
    }
    self.assertEqual(indexer.entityIdMap(), expected_entity_ids)

    expected_relation_ids = {
        '/country': 1,
        '/produced_by': 0,
        '/produced_in': 2,
    }
    self.assertEqual(indexer.relationIdMap(), expected_relation_ids)

    expected_indexes = [
        kgedata.TripleIndex(h, r, t)
        for h, r, t in ((0, 0, 1), (1, 1, 2), (0, 2, 2), (2, 2, 3), (0, 1, 3))
    ]
    self.assertEqual(indexer.indexes(), expected_indexes)

    expected_entities = ['/m/entity1', '/m/entity2', '/m/entity3', '/m/entity4']
    self.assertEqual(indexer.entities(), expected_entities)

    expected_relations = ['/produced_by', '/country', '/produced_in']
    self.assertEqual(indexer.relations(), expected_relations)
def shrink_indexes_in_place(self, triples):
    """Drop unused ids and renumber the remaining ones densely.

    Marks every entity and relation id referenced by ``triples``, assigns the
    marked ids new consecutive numbers in ascending order of their old id,
    and replaces this object's entity/relation maps and counters with the
    compacted ones.

    Args:
        triples: indexed triples (``head``, ``relation``, ``tail`` ids); the
            collection is traversed twice and is not modified.

    Returns:
        A new list of ``kgedata.TripleIndex`` expressed in the compacted ids.
    """
    ent_used = [False] * self._ent_id
    rel_used = [False] * self._rel_id
    for triple in triples:
        ent_used[triple.head] = True
        ent_used[triple.tail] = True
        rel_used[triple.relation] = True

    # Renumber the entities that survived, keeping their relative order.
    compact_ents = bidict()
    next_ent = 0
    for old_id, used in enumerate(ent_used):
        if not used:
            continue
        compact_ents[self._ents.inverse[old_id]] = next_ent
        next_ent += 1
    logging.info(
        f"before shrinking: {self._ent_id}\nafter shrinking: {next_ent}")

    # Same compaction for relations.
    compact_rels = bidict()
    next_rel = 0
    for old_id, used in enumerate(rel_used):
        if not used:
            continue
        compact_rels[self._rels.inverse[old_id]] = next_rel
        next_rel += 1
    logging.info(
        f"before shrinking: {self._rel_id}\nafter shrinking: {next_rel}")

    # Translate old id -> name -> new id while the old maps are still in place.
    new_triples = []
    for triple in triples:
        head_name = self._ents.inverse[triple.head]
        relation_name = self._rels.inverse[triple.relation]
        tail_name = self._ents.inverse[triple.tail]
        new_triples.append(
            kgedata.TripleIndex(compact_ents[head_name],
                                compact_rels[relation_name],
                                compact_ents[tail_name]))

    self._ent_id = next_ent
    self._ents = compact_ents
    self._rel_id = next_rel
    self._rels = compact_rels
    return new_triples
def build_index_and_mapping(triples):
    """Index literal triples and return them with the name/id mappings.

    Ids are assigned from 0 in first-appearance order; within a triple the
    head is registered before the tail.

    Args:
        triples: iterable of objects exposing ``head``, ``relation``, ``tail``.

    Returns:
        ``(indexed, entity_ids, relation_ids)``: the list of
        ``kgedata.TripleIndex`` plus the entity and relation ``bidict``s.
    """
    entity_ids = bidict()
    relation_ids = bidict()
    indexed = []

    for triple in triples:
        for name in (triple.head, triple.tail):
            if name not in entity_ids:
                # Every stored id equals the size at insertion time.
                entity_ids[name] = len(entity_ids)
        if triple.relation not in relation_ids:
            relation_ids[triple.relation] = len(relation_ids)

        indexed.append(
            kgedata.TripleIndex(entity_ids[triple.head],
                                relation_ids[triple.relation],
                                entity_ids[triple.tail]))

    return indexed, entity_ids, relation_ids
def evaulate_prediction_np_collate(model, triple_source, config, ranker,
                                   data_loader):
    """Rank head, tail and relation predictions for every triple served.

    Intended for use with NumpyCollate. Emits a deprecation notice and puts
    ``model`` into eval mode before iterating. Which elements are evaluated
    is controlled by the flags in ``config.report_dimension``.

    Returns:
        Three ``(ranks, filtered_ranks)`` pairs of lists, for head, tail and
        relation in that order. The lists are handed to
        ``_evaluate_predict_element``, which is expected to fill them (not
        visible from here).
    """
    utils.deprecation("multiprocess validation is not in use any more",
                      "0.5.0")
    model.eval()

    head_ranks, filtered_head_ranks = [], []
    tail_ranks, filtered_tail_ranks = [], []
    relation_ranks, filtered_relation_ranks = [], []

    for sample_batched in data_loader:
        # Each batch is a list of triples; per the original note a triple has
        # shape (1, 3), so its single row supplies head, relation and tail.
        for triple in sample_batched:
            triple_index = kgedata.TripleIndex(*triple[0, :])

            wants_entities = (
                (config.report_dimension
                 & stats.StatisticsDimension.SEPERATE_ENTITY)
                or (config.report_dimension
                    & stats.StatisticsDimension.COMBINED_ENTITY))
            if wants_entities:
                _evaluate_predict_element(
                    model, config, triple_index, triple_source.num_entity,
                    data.TripleElement.HEAD, ranker.rank_head,
                    head_ranks, filtered_head_ranks)
                _evaluate_predict_element(
                    model, config, triple_index, triple_source.num_entity,
                    data.TripleElement.TAIL, ranker.rank_tail,
                    tail_ranks, filtered_tail_ranks)

            wants_relations = (config.report_dimension
                               & stats.StatisticsDimension.RELATION)
            if wants_relations:
                _evaluate_predict_element(
                    model, config, triple_index, triple_source.num_relation,
                    data.TripleElement.RELATION, ranker.rank_relation,
                    relation_ranks, filtered_relation_ranks)

    head_result = (head_ranks, filtered_head_ranks)
    tail_result = (tail_ranks, filtered_tail_ranks)
    relation_result = (relation_ranks, filtered_relation_ranks)
    return head_result, tail_result, relation_result
def test_unwrap():
    """unpack() yields a plain 3-tuple for index and literal triples alike."""
    indexed = kgedata.TripleIndex(1, 2, 3)
    assert kgekit.data.unpack(indexed) == (1, 2, 3)

    literal = kgedata.Triple("/m/1", "/m/2", "/m/3")
    assert kgekit.data.unpack(literal) == ("/m/1", "/m/2", "/m/3")
def test_pack_triples_numpy():
    """pack_triples_numpy() stacks triples into an int64 array, one per row."""
    triples = [kgedata.TripleIndex(1, 2, 3), kgedata.TripleIndex(4, 5, 6)]
    expected = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    packed = kgekit.data.pack_triples_numpy(triples)

    np.testing.assert_equal(packed, expected)
def test_transform_triple_numpy():
    """_transform_triple_numpy() converts one triple to a 1-D int64 array."""
    triple = kgedata.TripleIndex(1, 2, 3)
    expected = np.array([1, 2, 3], dtype=np.int64)

    transformed = kgekit.data._transform_triple_numpy(triple)

    np.testing.assert_equal(transformed, expected)
def test_get_triple_index(self):
    """get_triple_index() parses a line according to the given field order."""
    line = "1 2 3"

    # "hrt": fields are read as head, relation, tail.
    parsed_hrt = kgedata.get_triple_index(line, "hrt", ' ')
    self.assertEqual(parsed_hrt, kgedata.TripleIndex(1, 2, 3))

    # "htr": the second field is the tail and the third the relation.
    parsed_htr = kgedata.get_triple_index(line, "htr", ' ')
    self.assertEqual(parsed_htr, kgedata.TripleIndex(1, 3, 2))
def test_triple_index(self):
    """A TripleIndex built positionally equals the field-by-field fixture."""
    built = kgedata.TripleIndex(*self.index)
    self.assertEqual(built, self.triple_index)