def load_custom(data_path=None, split_size=(0.8, 0.1, 0.1)):
    """Load a custom knowledge graph from a CSV file and split it.

    The file is read with ``sep=","`` and no header row; its three columns
    are interpreted, in order, as head (``from``), relation (``rel``) and
    tail (``to``). The rows are passed through ``shuffle`` before the graph
    is built and split.

    Parameters
    ----------
    data_path: str
        Path to the CSV file containing the triplets.
    split_size: sequence of float, optional (default=(0.8, 0.1, 0.1))
        Shares of the facts used for the training and validation splits.
        The test split receives all remaining facts, so only the first two
        values are read.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    Raises
    ------
    ValueError
        If `data_path` is None.
    """
    if data_path is None:
        raise ValueError("data_path must point to a CSV file of triplets.")

    df = pd.read_csv(data_path, sep=",", header=None,
                     names=["from", "rel", "to"])
    df = shuffle(df)
    kg = KnowledgeGraph(df)

    n_facts = len(df)
    train_size = int(n_facts * split_size[0])
    val_size = int(n_facts * split_size[1])
    # The remainder goes to the test split so that the three sizes always
    # add up to the number of facts, whatever the rounding above did.
    test_size = n_facts - (train_size + val_size)
    return kg.split_kg(sizes=(train_size, val_size, test_size))
def load_fb15k237(data_home=None):
    """Load FB15k-237 dataset.

    See `here <https://www.aclweb.org/anthology/D15-1174/>`__ for paper by
    Toutanova et al. originally presenting the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/FB15k237'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        # NOTE(review): this URL uses a different path layout (/datasets/)
        # from the other loaders (/data/torchkge/kgs/) — confirm it is still
        # served.
        urlretrieve(
            "https://graphs.telecom-paristech.fr/datasets/FB15k237.zip",
            data_home + '/FB15k237.zip')
        with zipfile.ZipFile(data_home + '/FB15k237.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/FB15k237.zip')
        # Drop the `__MACOSX` folder extracted next to the data, without
        # failing when the archive does not contain one.
        shutil.rmtree(data_home + '/__MACOSX', ignore_errors=True)

    # Read the three splits in train / valid / test order; their lengths are
    # used to cut the merged graph back along the original boundaries.
    splits = [
        read_csv(data_path + split_file, sep='\t', header=None,
                 names=['from', 'rel', 'to'])
        for split_file in ('/train.txt', '/valid.txt', '/test.txt')
    ]
    df = concat(splits)
    kg = KnowledgeGraph(df)
    return kg.split_kg(sizes=tuple(len(split) for split in splits))
def load_fb15k(data_home=None):
    """Load FB15k dataset.

    See `here
    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data>`__
    for paper by Bordes et al. originally presenting the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/FB15k'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB15k.zip",
            data_home + '/FB15k.zip')
        with zipfile.ZipFile(data_home + '/FB15k.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/FB15k.zip')
        # Drop the `__MACOSX` folder extracted next to the data, without
        # failing when the archive does not contain one.
        shutil.rmtree(data_home + '/__MACOSX', ignore_errors=True)

    # Read the three splits in train / valid / test order; their lengths are
    # used to cut the merged graph back along the original boundaries.
    splits = [
        read_csv(data_path + split_file, sep='\t', header=None,
                 names=['from', 'rel', 'to'])
        for split_file in ('/freebase_mtr100_mte100-train.txt',
                           '/freebase_mtr100_mte100-valid.txt',
                           '/freebase_mtr100_mte100-test.txt')
    ]
    df = concat(splits)
    kg = KnowledgeGraph(df)
    return kg.split_kg(sizes=tuple(len(split) for split in splits))
def test_get_bernoulli_probs(self):
    """Check `get_bernoulli_probs` against hand-computed probabilities.

    The expected values are written with 4 decimals, hence the 1e-03
    tolerance.
    """
    kg = KnowledgeGraph(df=self.df)
    probs = get_bernoulli_probs(kg)
    expected = {0: 0.5714, 1: 0.5, 2: 0.5, 3: 0.5}
    for rel in probs.keys():
        # abs() is required: a signed difference would accept any computed
        # probability that is larger than the expected one.
        assert abs(expected[rel] - probs[rel]) < 1e-03
def load_wn18rr(data_home=None):
    """Load WN18RR dataset.

    See `here <https://arxiv.org/abs/1707.01476>`__ for paper by Dettmers et
    al. originally presenting the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    home = get_data_home() if data_home is None else data_home
    data_path = home + '/WN18RR'
    archive = home + '/WN18RR.zip'

    # Download and unpack the archive only when the dataset folder is missing.
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18RR.zip",
            archive)
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(home)
        remove(archive)

    # Splits are kept in train / valid / test order so their lengths can be
    # used to split the merged graph back into the same three parts.
    frames = [
        read_csv(data_path + split_file, sep='\t', header=None,
                 names=['from', 'rel', 'to'])
        for split_file in ('/train.txt', '/valid.txt', '/test.txt')
    ]
    kg = KnowledgeGraph(concat(frames))
    return kg.split_kg(sizes=tuple(len(frame) for frame in frames))
def load_wn18(data_home=None):
    """Load WN18 dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/WN18'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18.zip",
            data_home + '/WN18.zip')
        with zipfile.ZipFile(data_home + '/WN18.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/WN18.zip')
        # Drop the `__MACOSX` folder extracted next to the data, without
        # failing when the archive does not contain one.
        shutil.rmtree(data_home + '/__MACOSX', ignore_errors=True)

    # Read the three splits in train / valid / test order; their lengths are
    # used to cut the merged graph back along the original boundaries.
    splits = [
        read_csv(data_path + split_file, sep='\t', header=None,
                 names=['from', 'rel', 'to'])
        for split_file in ('/wordnet-mlj12-train.txt',
                           '/wordnet-mlj12-valid.txt',
                           '/wordnet-mlj12-test.txt')
    ]
    df = concat(splits)
    kg = KnowledgeGraph(df)
    return kg.split_kg(sizes=tuple(len(split) for split in splits))
def test_KnowledgeGraph_Builder(self):
    """Check the attributes of `self.kg` and the constructor's argument checks."""
    assert len(self.kg) == 9
    assert self.kg.n_ent == 6
    assert self.kg.n_rel == 4
    # Both mapping dictionaries are checked (the original checked rel2ix twice
    # and never looked at ent2ix).
    assert type(self.kg.ent2ix) == dict
    assert type(self.kg.rel2ix) == dict
    assert (type(self.kg.head_idx) == Tensor) & (type(self.kg.tail_idx) == Tensor) & \
        (type(self.kg.relations) == Tensor)
    assert (self.kg.head_idx.dtype == int64) & (self.kg.tail_idx.dtype == int64) & \
        (self.kg.relations.dtype == int64)
    assert (len(self.kg.head_idx) == len(self.kg.tail_idx) == len(self.kg.relations))

    kg_dict = {'heads': self.kg.head_idx,
               'tails': self.kg.tail_idx,
               'relations': self.kg.relations}

    # Invalid argument combinations.
    with self.assertRaises(WrongArgumentsError):
        KnowledgeGraph()
    with self.assertRaises(WrongArgumentsError):
        KnowledgeGraph(kg=kg_dict, df=self.df)
    with self.assertRaises(WrongArgumentsError):
        KnowledgeGraph(kg=kg_dict)
    with self.assertRaises(WrongArgumentsError):
        KnowledgeGraph(kg={'heads': self.kg.head_idx,
                           'tails': self.kg.tail_idx},
                       ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)

    # Inconsistent tensors: length mismatch, then dtype mismatch.
    with self.assertRaises(SanityError):
        KnowledgeGraph(kg={'heads': self.kg.head_idx[:-1],
                           'tails': self.kg.tail_idx,
                           'relations': self.kg.relations},
                       ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)
    with self.assertRaises(SanityError):
        KnowledgeGraph(kg={'heads': self.kg.head_idx.int(),
                           'tails': self.kg.tail_idx,
                           'relations': self.kg.relations},
                       ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)
class TestUtils(unittest.TestCase):
    """Tests for `torchkge.utils`."""

    def setUp(self):
        # Nine facts over six entities (0-5) and four distinct relations.
        self.df = pd.DataFrame([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
                                [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4],
                                [5, 4, 0]],
                               columns=['from', 'to', 'rel'])
        self.kg = KnowledgeGraph(self.df)

    def test_KnowledgeGraph_Builder(self):
        """Check the attributes of `self.kg` and the constructor's argument checks."""
        assert len(self.kg) == 9
        assert self.kg.n_ent == 6
        assert self.kg.n_rel == 4
        # Both mapping dictionaries are checked (the original checked rel2ix
        # twice and never looked at ent2ix).
        assert type(self.kg.ent2ix) == dict
        assert type(self.kg.rel2ix) == dict
        assert (type(self.kg.head_idx) == Tensor) & (type(self.kg.tail_idx) == Tensor) & \
            (type(self.kg.relations) == Tensor)
        assert (self.kg.head_idx.dtype == int64) & (self.kg.tail_idx.dtype == int64) & \
            (self.kg.relations.dtype == int64)
        assert (len(self.kg.head_idx) == len(self.kg.tail_idx) == len(self.kg.relations))

        kg_dict = {'heads': self.kg.head_idx,
                   'tails': self.kg.tail_idx,
                   'relations': self.kg.relations}

        # Invalid argument combinations.
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph()
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg=kg_dict, df=self.df)
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg=kg_dict)
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx,
                               'tails': self.kg.tail_idx},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)

        # Inconsistent tensors: length mismatch, then dtype mismatch.
        with self.assertRaises(SanityError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx[:-1],
                               'tails': self.kg.tail_idx,
                               'relations': self.kg.relations},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)
        with self.assertRaises(SanityError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx.int(),
                               'tails': self.kg.tail_idx,
                               'relations': self.kg.relations},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)

    def test_split_kg(self):
        """Check the number of returned parts and the rejection of bad sizes."""
        assert (len(self.kg.split_kg()) == 2) & (len(self.kg.split_kg(validation=True)) == 3)
        with self.assertRaises(SizeMismatchError):
            self.kg.split_kg(sizes=(1, 2, 3, 4))
        with self.assertRaises(WrongArgumentsError):
            self.kg.split_kg(sizes=(9, 9, 9))
        with self.assertRaises(WrongArgumentsError):
            self.kg.split_kg(sizes=(9, 9))
class TestUtils(unittest.TestCase):
    """Sanity checks for `LinkPredictionEvaluator` and `TripletClassificationEvaluator`."""

    def setUp(self):
        facts = [[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1],
                 [1, 3, 2], [2, 4, 0], [3, 4, 4], [5, 4, 0]]
        frame = pd.DataFrame(facts, columns=['from', 'to', 'rel'])
        self.kg = KnowledgeGraph(frame)

    def checkSanityLinkPrediction(self, evaluator):
        """Assert that all four rank tensors are `long` and have one entry per fact."""
        rank_tensors = (evaluator.rank_true_heads,
                        evaluator.rank_true_tails,
                        evaluator.filt_rank_true_heads,
                        evaluator.filt_rank_true_tails)
        for ranks in rank_tensors:
            assert ranks.dtype == long
        n_facts = len(self.kg)
        for ranks in rank_tensors:
            assert ranks.shape[0] == n_facts

    def test_LinkPredictionEvaluator(self):
        model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1')
        evaluator = LinkPredictionEvaluator(model, self.kg)
        # The rank tensors must be well-formed both before and after evaluation.
        self.checkSanityLinkPrediction(evaluator)
        evaluator.evaluate(b_size=len(self.kg), k_max=10)
        self.checkSanityLinkPrediction(evaluator)

    def test_TripletClassificationEvaluator(self):
        model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1')
        kg_test, kg_val = self.kg.split_kg(sizes=(4, 5))
        # kg_val (the 5-fact part) contains all relations, so it serves as
        # the validation graph.
        evaluator = TripletClassificationEvaluator(model, kg_val, kg_test)

        assert evaluator.thresholds is None
        assert not evaluator.evaluated

        evaluator.evaluate(b_size=len(self.kg))

        assert evaluator.evaluated
        thresholds = evaluator.thresholds
        assert thresholds is not None
        assert (len(thresholds.shape) == 1) & (thresholds.shape[0] == self.kg.n_rel)
def test_get_possible_heads_tails(self):
    """Check possible heads/tails per relation, with and without extra candidates."""
    graph = KnowledgeGraph(self.df)

    heads, tails = get_possible_heads_tails(graph)
    assert (type(heads) == dict) & (type(tails) == dict)
    assert heads == {0: {0, 2, 5}, 1: {1}, 2: {1}, 3: {3}}
    assert tails == {0: {1, 2, 3, 4}, 1: {2}, 2: {3}, 3: {4}}

    # Extra candidates are merged in, including ones for relation 10, which
    # does not appear among the graph's relations above.
    extra_heads = {0: {40}, 10: {50}}
    extra_tails = {0: {41}, 10: {51}}
    heads, tails = get_possible_heads_tails(graph,
                                            possible_heads=extra_heads,
                                            possible_tails=extra_tails)
    assert heads == {0: {0, 2, 5, 40}, 1: {1}, 2: {1}, 3: {3}, 10: {50}}
    assert tails == {0: {1, 2, 3, 4, 41}, 1: {2}, 2: {3}, 3: {4}, 10: {51}}
def load_Sweden(data_home=None, GDR=False):
    """Load the Sweden knowledge graph from the local `Sweden` data folder.

    If `train.txt`, `valid.txt` and `test.txt` all exist in
    `<data_home>/Sweden`, they are read, concatenated and split back along
    the original file boundaries. Otherwise `triplets.txt` is read and split
    with `share=0.8, validation=True`. The resulting splits are then passed to
    `data_save` (presumably so later calls take the first branch — verify).
    Nothing is downloaded.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders).
        Defaults to `get_data_home()`.
    GDR: bool, optional (default=False)
        If true, the path `<data_home>/Sweden/ent2point.txt` is passed as
        `geo` to `KnowledgeGraph`, `split_kg` and `data_save`; otherwise
        `geo=None` is passed.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/Sweden'
    geo = data_path + '/ent2point.txt' if GDR else None

    if (exists(data_path + '/train.txt') and exists(data_path + '/test.txt')
            and exists(data_path + '/valid.txt')):
        # Pre-split data: keep the existing train / valid / test boundaries.
        df1 = read_csv(data_path + '/train.txt', sep='\t', header=None,
                       names=['from', 'rel', 'to'])
        df2 = read_csv(data_path + '/valid.txt', sep='\t', header=None,
                       names=['from', 'rel', 'to'])
        df3 = read_csv(data_path + '/test.txt', sep='\t', header=None,
                       names=['from', 'rel', 'to'])
        df = concat([df1, df2, df3])
        kg = KnowledgeGraph(df=df, geo=geo)
        kg_train, kg_val, kg_test = kg.split_kg(
            sizes=(len(df1), len(df2), len(df3)), geo=geo)
        return kg_train, kg_val, kg_test

    # No split yet: split the full triplet file and persist the result.
    df = read_csv(data_path + '/triplets.txt', sep='\t', header=None,
                  names=['from', 'rel', 'to'], encoding='utf-8')
    kg = KnowledgeGraph(df=df, geo=geo)
    kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True, geo=geo)
    data_save('/Sweden', kg_train, kg_val, kg_test, geo=geo)
    return kg_train, kg_val, kg_test
def load_wikidata_vitals(level=5, data_home=None):
    """Load knowledge graph extracted from Wikidata using the entities
    corresponding to Wikipedia pages contained in Wikivitals. See `here
    <https://netset.telecom-paris.fr/>`__ for details on Wikivitals and
    Wikivitals+ datasets.

    Parameters
    ----------
    level: int (default=5)
        Either 4 or 5.
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg: torchkge.data_structures.KnowledgeGraph
        Graph built from `edges.tsv`.
    kg_attr: torchkge.data_structures.KnowledgeGraph
        Graph built from `attributes.tsv`.

    Both graphs are given the same `relid2label`, `entid2label` and
    `entid2pagename` dictionaries, keyed by the `wikidataID` column.
    """
    assert level in [4, 5]

    home = data_home if data_home is not None else get_data_home()
    data_path = home + '/wikidatavitals-level{}'.format(level)
    archive = home + '/wikidatavitals-level{}.zip'.format(level)

    # Download and unpack the archive only when the dataset folder is missing.
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        url = ("https://graphs.telecom-paristech.fr/data/torchkge/kgs/wikidatavitals-level{}.zip"
               .format(level))
        urlretrieve(url, archive)
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(home)
        remove(archive)

    def _read_table(file_name, **options):
        # Every table in the folder is tab-separated.
        return read_csv(data_path + file_name, sep='\t', **options)

    df = _read_table('/edges.tsv', names=['from', 'to', 'rel'], skiprows=1)
    attributes = _read_table('/attributes.tsv', names=['from', 'to', 'rel'],
                             skiprows=1)
    entities = _read_table('/entities.tsv')
    relations = _read_table('/relations.tsv')
    nodes = _read_table('/nodes.tsv')

    df = enrich(df, entities, relations)
    attributes = enrich(attributes, entities, relations)

    def _id_map(table, value_column):
        # Map each row's `wikidataID` to `value_column`, looked up row by row.
        return {table.loc[i, 'wikidataID']: table.loc[i, value_column]
                for i in table.index}

    relid2label = _id_map(relations, 'label')
    entid2label = _id_map(entities, 'label')
    entid2pagename = _id_map(nodes, 'pageName')

    kg = KnowledgeGraph(df)
    kg_attr = KnowledgeGraph(attributes)
    for graph in (kg, kg_attr):
        graph.relid2label = relid2label
        graph.entid2label = entid2label
        graph.entid2pagename = entid2pagename
    return kg, kg_attr
def load_wikidatasets(which, limit_=0, data_home=None):
    """Load WikiDataSets dataset.

    See `here <https://arxiv.org/abs/1906.04536>`__ for paper by Boschin et
    al. originally presenting the dataset.

    Parameters
    ----------
    which: str
        String indicating which subset of Wikidata should be loaded.
        Available ones are `humans`, `companies`, `animals`, `countries` and
        `films`.
    limit_: int, optional (default=0)
        Lower limit on the number of facts (as head plus as tail) an entity
        must appear in to be kept. A fact is kept when at least one of its
        two entities passes the limit. Values <= 0 disable the filter.
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg: torchkge.data_structures.KnowledgeGraph
        Graph whose triplets use entity and relation labels, with `ent2ix` /
        `rel2ix` built from the de-duplicated labels.
    """
    assert which in ['humans', 'companies', 'animals', 'countries', 'films']

    if data_home is None:
        data_home = get_data_home()
    data_home = data_home + '/WikiDataSets'
    data_path = data_home + '/' + which
    if not exists(data_path):
        print(f"Downloading WikiDataSets/{which}")
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/WikiDataSets/{}.tar.gz".
            format(which), data_home + '/{}.tar.gz'.format(which))
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; on Python >= 3.12 consider passing filter='data'.
        with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf:
            tf.extractall(data_home)
        remove(data_home + '/{}.tar.gz'.format(which))

    print(
        f"Creating Knowledge Graph Data Structure using WikiDataSets/{which}")
    df = read_csv(data_path + '/edges.tsv', sep='\t',
                  names=['from', 'to', 'rel'], skiprows=[0])
    entities = read_csv(data_path + '/entities.tsv', sep='\t',
                        names=['id', 'wid', 'label'], skiprows=[0])
    relations = read_csv(data_path + '/relations.tsv', sep='\t',
                         names=['id', 'wid', 'label'], skiprows=[0])

    # Replace numeric ids by labels. Column-wise comprehensions avoid one
    # pandas row assignment per fact; an unknown id still raises KeyError.
    ix2ent = dict(zip(entities['id'], entities['label']))
    ix2rel = dict(zip(relations['id'], relations['label']))
    df['from'] = [ix2ent[h] for h in df['from']]
    df['to'] = [ix2ent[t] for t in df['to']]
    df['rel'] = [ix2rel[r] for r in df['rel']]

    # Build entity2idx / relation2idx from the de-duplicated labels.
    entities.drop_duplicates('label', inplace=True)
    relations.drop_duplicates('label', inplace=True)
    ent2ix = {e: i for i, e in enumerate(entities['label'])}
    rel2ix = {r: i for i, r in enumerate(relations['label'])}

    if limit_ > 0:
        # Degree of an entity = facts where it is the head + facts where it is
        # the tail. Summing the two counts by index keeps entities that only
        # ever appear on one side (the previous outer-merge + fillna(0)
        # replaced head-only entities by 0 and so never kept them).
        degree = df['from'].value_counts().add(df['to'].value_counts(),
                                               fill_value=0)
        kept = degree[degree >= limit_].index
        df_bis = df.loc[df['from'].isin(kept) | df['to'].isin(kept)]
        kg = KnowledgeGraph(df=df_bis, ent2ix=ent2ix, rel2ix=rel2ix)
    else:
        kg = KnowledgeGraph(df=df, ent2ix=ent2ix, rel2ix=rel2ix)

    return kg
def load_fb13(data_home=None):
    """Load FB13 dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    home = get_data_home() if data_home is None else data_home
    data_path = home + '/FB13'
    archive = home + '/FB13.zip'

    # Download and unpack the archive only when the dataset folder is missing.
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB13.zip",
            archive)
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(home)
        remove(archive)

    # Each *2id.txt file has one leading line, skipped before parsing.
    splits = [
        read_csv(data_path + split_file, sep='\t', header=None,
                 names=['from', 'rel', 'to'], skiprows=[0])
        for split_file in ('/train2id.txt', '/valid2id.txt', '/test2id.txt')
    ]

    def _read_mapping(file_name, key_column):
        # Two-column table (name, idx) turned into a {name: idx} dictionary.
        table = read_csv(data_path + file_name, sep='\t', header=None,
                         names=[key_column, 'idx'], skiprows=[0])
        return dict(zip(table[key_column], table['idx']))

    ent2idx = _read_mapping('/entity2id.txt', 'entity')
    rel2idx = _read_mapping('/relation2id.txt', 'relation')

    kg = KnowledgeGraph(df=concat(splits), ent2ix=ent2idx, rel2ix=rel2idx)
    return kg.split_kg(sizes=tuple(len(split) for split in splits))
def setUp(self):
    """Build a small nine-fact knowledge graph shared by the tests."""
    rows = [[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1],
            [1, 3, 2], [2, 4, 0], [3, 4, 4], [5, 4, 0]]
    frame = pd.DataFrame(rows, columns=['from', 'to', 'rel'])
    self.kg = KnowledgeGraph(frame)
def test_get_hpt(self):
    """Check `get_hpt` on the (head, tail, relation) columns of the test graph."""
    kg = KnowledgeGraph(df=self.df)
    # Stack heads, tails and relations as three columns of one tensor.
    columns = (kg.head_idx.view(-1, 1),
               kg.tail_idx.view(-1, 1),
               kg.relations.view(-1, 1))
    triplets = cat(columns, dim=1)
    expected = {0: 1.5, 1: 1., 2: 1., 3: 1.}
    assert get_hpt(triplets) == expected
def load_wikidatasets(which, limit_=None, data_home=None):
    """Load WikiDataSets dataset.

    See `here <https://arxiv.org/abs/1906.04536>`__ for paper by Boschin et
    al. originally presenting the dataset.

    Parameters
    ----------
    which: str
        String indicating which subset of Wikidata should be loaded.
        Available ones are `humans`, `companies`, `animals`, `countries` and
        `films`.
    limit_: int, optional (default=None)
        Lower limit on the number of facts (as head plus as tail) an entity
        must appear in to be kept. A fact is kept when at least one of its
        two entities passes the limit. None is treated as 0 (no filtering).
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """
    assert which in ['humans', 'companies', 'animals', 'countries', 'films']

    if data_home is None:
        data_home = get_data_home()
    data_home = data_home + '/WikiDataSets'
    data_path = data_home + '/' + which
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/WikiDataSets/{}.tar.gz".
            format(which), data_home + '/{}.tar.gz'.format(which))
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; on Python >= 3.12 consider passing filter='data'.
        with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf:
            tf.extractall(data_home)
        remove(data_home + '/{}.tar.gz'.format(which))

    # NOTE(review): with header=1 pandas discards line 0 and consumes line 1
    # as the (renamed) header, so two leading lines are skipped — confirm
    # that edges.txt really starts with two non-data lines.
    df = read_csv(data_path + '/edges.txt', sep='\t', header=1,
                  names=['from', 'to', 'rel'])

    # The default None previously reached `>= None`, which pandas evaluates
    # as all-False and so produced an empty graph.
    if limit_ is None:
        limit_ = 0

    # Filter out nodes with too few facts. Degree = facts as head + facts as
    # tail. Summing the two counts by index also keeps entities that only
    # appear as heads (the previous outer-merge + fillna(0) replaced their
    # ids by 0).
    degree = df['from'].value_counts().add(df['to'].value_counts(),
                                           fill_value=0)
    kept = degree[degree >= limit_].index
    df_bis = df.loc[df['from'].isin(kept) | df['to'].isin(kept)]

    kg = KnowledgeGraph(df_bis)
    kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True)
    return kg_train, kg_val, kg_test