Example #1
import pandas as pd
from sklearn.utils import shuffle

from torchkge.data_structures import KnowledgeGraph


def load_custom(data_path=None, split_size=(0.8, 0.1, 0.1)):
    """Load a custom dataset from a CSV file of triples and split it into
    training, validation and test knowledge graphs.

    Parameters
    ----------
    data_path: str
        Path to a comma-separated file with one `head, relation, tail`
        triple per line and no header row.
    split_size: tuple of float, optional (default=(0.8, 0.1, 0.1))
        Shares of the facts assigned to the training, validation and test
        subsets, in that order.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph
    """

    df = pd.read_csv(data_path,
                     sep=",",
                     header=None,
                     names=["from", "rel", "to"])
    df = shuffle(df)
    kg = KnowledgeGraph(df)
    train_size = int(len(df) * split_size[0])
    val_size = int(len(df) * split_size[1])
    test_size = len(df) - (train_size + val_size)
    return kg.split_kg(sizes=(train_size, val_size, test_size))
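A minimal usage sketch for the loader above; `triples.csv` is a hypothetical path to a headerless CSV of `head,relation,tail` rows:

# Hypothetical input file; any headerless CSV of head,relation,tail rows works.
kg_train, kg_val, kg_test = load_custom("triples.csv",
                                        split_size=(0.8, 0.1, 0.1))
print(len(kg_train), len(kg_val), len(kg_test))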
Example #2
import shutil
import zipfile
from os import makedirs, remove
from os.path import exists
from urllib.request import urlretrieve

from pandas import concat, read_csv

from torchkge.data_structures import KnowledgeGraph
# get_data_home (resolving the default data directory) is assumed to be in
# scope; in torchkge it lives next to these loaders in the datasets module.


def load_fb15k237(data_home=None):
    """Load the FB15k237 dataset. See `here
    <https://www.aclweb.org/anthology/D15-1174/>`__ for the paper by Toutanova
    et al. that originally presented the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/FB15k237'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/datasets/FB15k237.zip",
            data_home + '/FB15k237.zip')
        with zipfile.ZipFile(data_home + '/FB15k237.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/FB15k237.zip')
        shutil.rmtree(data_home + '/__MACOSX')

    df1 = read_csv(data_path + '/train.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df2 = read_csv(data_path + '/valid.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df3 = read_csv(data_path + '/test.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df = concat([df1, df2, df3])
    kg = KnowledgeGraph(df)

    return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))
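The same calling pattern works for the other download-based loaders below (load_fb15k, load_wn18rr, load_wn18, load_fb13); a quick sketch, using only attributes the returned KnowledgeGraph objects expose (the first call downloads the data into `torchkge_data`):

kg_train, kg_val, kg_test = load_fb15k237()
print(kg_train.n_ent, kg_train.n_rel)            # entity and relation counts
print(len(kg_train), len(kg_val), len(kg_test))  # facts per split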
Example #3
def load_fb15k(data_home=None):
    """Load the FB15k dataset. See `here
    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data>`__
    for the paper by Bordes et al. that originally presented the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/FB15k'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB15k.zip",
            data_home + '/FB15k.zip')
        with zipfile.ZipFile(data_home + '/FB15k.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/FB15k.zip')
        shutil.rmtree(data_home + '/__MACOSX')

    df1 = read_csv(data_path + '/freebase_mtr100_mte100-train.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df2 = read_csv(data_path + '/freebase_mtr100_mte100-valid.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df3 = read_csv(data_path + '/freebase_mtr100_mte100-test.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df = concat([df1, df2, df3])
    kg = KnowledgeGraph(df)

    return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))
Example #4
    def test_get_bernoulli_probs(self):
        kg = KnowledgeGraph(df=self.df)
        probs = get_bernoulli_probs(kg)
        res = {0: 0.5714, 1: 0.5, 2: 0.5, 3: 0.5}

        for k in probs.keys():
            assert abs(res[k] - probs[k]) < 1e-03
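The expected values come from the Bernoulli negative-sampling scheme of Wang et al. (TransH), where each relation corrupts the head with probability tph / (tph + hpt). Assuming the same toy fixture as the other tests here, relation 0 has six facts: its three distinct heads have 4, 1 and 1 tails (tph = 6/3 = 2) and its four distinct tails have 1, 1, 1 and 3 heads (hpt = 6/4 = 1.5), hence 2 / 3.5 ≈ 0.5714; the remaining relations each occur in a single fact, giving 1 / 2 = 0.5.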
Example #5
def load_wn18rr(data_home=None):
    """Load the WN18RR dataset. See `here
    <https://arxiv.org/abs/1707.01476>`__ for the paper by Dettmers et
    al. that originally presented the dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/WN18RR'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18RR.zip",
            data_home + '/WN18RR.zip')
        with zipfile.ZipFile(data_home + '/WN18RR.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/WN18RR.zip')

    df1 = read_csv(data_path + '/train.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df2 = read_csv(data_path + '/valid.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df3 = read_csv(data_path + '/test.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df = concat([df1, df2, df3])
    kg = KnowledgeGraph(df)

    return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))
Example #6
def load_wn18(data_home=None):
    """Load the WN18 dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/WN18'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/WN18.zip",
            data_home + '/WN18.zip')
        with zipfile.ZipFile(data_home + '/WN18.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/WN18.zip')
        shutil.rmtree(data_home + '/__MACOSX')

    df1 = read_csv(data_path + '/wordnet-mlj12-train.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df2 = read_csv(data_path + '/wordnet-mlj12-valid.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df3 = read_csv(data_path + '/wordnet-mlj12-test.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'])
    df = concat([df1, df2, df3])
    kg = KnowledgeGraph(df)

    return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))
Example #7
import unittest

import pandas as pd
from torch import Tensor, int64

from torchkge.data_structures import KnowledgeGraph
# The custom error types used below are assumed to live in torchkge's
# exceptions module.
from torchkge.exceptions import (SanityError, SizeMismatchError,
                                 WrongArgumentsError)


class TestUtils(unittest.TestCase):
    """Tests for `torchkge.utils`."""

    def setUp(self):
        self.df = pd.DataFrame([[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1], [1, 3, 2], [2, 4, 0], [3, 4, 4],
                                [5, 4, 0]], columns=['from', 'to', 'rel'])
        self.kg = KnowledgeGraph(self.df)

    def test_KnowledgeGraph_Builder(self):
        assert len(self.kg) == 9
        assert self.kg.n_ent == 6
        assert self.kg.n_rel == 4
        assert (type(self.kg.ent2ix) == dict) & (type(self.kg.rel2ix) == dict)
        assert (type(self.kg.head_idx) == Tensor) & (type(self.kg.tail_idx) == Tensor) & \
               (type(self.kg.relations) == Tensor)
        assert (self.kg.head_idx.dtype == int64) & (self.kg.tail_idx.dtype == int64) & \
               (self.kg.relations.dtype == int64)
        assert (len(self.kg.head_idx) == len(self.kg.tail_idx) == len(self.kg.relations))

        kg_dict = {'heads': self.kg.head_idx, 'tails': self.kg.tail_idx, 'relations': self.kg.relations}
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph()
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg=kg_dict, df=self.df)
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg=kg_dict)
        with self.assertRaises(WrongArgumentsError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx, 'tails': self.kg.tail_idx},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)
        with self.assertRaises(SanityError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx[:-1],
                               'tails': self.kg.tail_idx,
                               'relations': self.kg.relations},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)
        with self.assertRaises(SanityError):
            KnowledgeGraph(kg={'heads': self.kg.head_idx.int(),
                               'tails': self.kg.tail_idx,
                               'relations': self.kg.relations},
                           ent2ix=self.kg.ent2ix, rel2ix=self.kg.rel2ix)

    def test_split_kg(self):
        assert (len(self.kg.split_kg()) == 2) & (len(self.kg.split_kg(validation=True)) == 3)
        with self.assertRaises(SizeMismatchError):
            self.kg.split_kg(sizes=(1, 2, 3, 4))
        with self.assertRaises(WrongArgumentsError):
            self.kg.split_kg(sizes=(9, 9, 9))
        with self.assertRaises(WrongArgumentsError):
            self.kg.split_kg(sizes=(9, 9))
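For reference, the `split_kg` calling conventions exercised above, sketched for any KnowledgeGraph `kg` (the default `share` is 0.8):

kg_train, kg_test = kg.split_kg()                         # two-way split
kg_train, kg_val, kg_test = kg.split_kg(validation=True)  # three-way split
kg_train, kg_val, kg_test = kg.split_kg(sizes=(7, 1, 1))  # explicit fact
                                                          # counts, must sum
                                                          # to len(kg)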
Example #8
import unittest

import pandas as pd
from torch import long

from torchkge.data_structures import KnowledgeGraph
from torchkge.evaluation import (LinkPredictionEvaluator,
                                 TripletClassificationEvaluator)
from torchkge.models import TransEModel


class TestUtils(unittest.TestCase):
    def setUp(self):
        df = pd.DataFrame(
            [[0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], [1, 2, 1], [1, 3, 2],
             [2, 4, 0], [3, 4, 4], [5, 4, 0]],
            columns=['from', 'to', 'rel'])
        self.kg = KnowledgeGraph(df)

    def checkSanityLinkPrediction(self, evaluator):
        assert evaluator.rank_true_heads.dtype == long
        assert evaluator.rank_true_tails.dtype == long
        assert evaluator.filt_rank_true_heads.dtype == long
        assert evaluator.filt_rank_true_tails.dtype == long

        assert evaluator.rank_true_heads.shape[0] == len(self.kg)
        assert evaluator.rank_true_tails.shape[0] == len(self.kg)
        assert evaluator.filt_rank_true_heads.shape[0] == len(self.kg)
        assert evaluator.filt_rank_true_tails.shape[0] == len(self.kg)

    def test_LinkPredictionEvaluator(self):
        model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1')

        evaluator = LinkPredictionEvaluator(model, self.kg)
        self.checkSanityLinkPrediction(evaluator)

        evaluator.evaluate(b_size=len(self.kg), k_max=10)
        self.checkSanityLinkPrediction(evaluator)

    def test_TripletClassificationEvaluator(self):
        model = TransEModel(100, self.kg.n_ent, self.kg.n_rel, 'L1')
        kg1, kg2 = self.kg.split_kg(sizes=(4, 5))
        # kg2 contains all relations so it will be used as validation
        evaluator = TripletClassificationEvaluator(model, kg2, kg1)
        assert evaluator.thresholds is None
        assert not evaluator.evaluated

        evaluator.evaluate(b_size=len(self.kg))
        assert evaluator.evaluated
        assert evaluator.thresholds is not None
        assert (len(evaluator.thresholds.shape)
                == 1) & (evaluator.thresholds.shape[0] == self.kg.n_rel)
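Once `evaluate` has run, both evaluators expose their results. A hedged sketch, assuming a trained `model` and a `kg_val`/`kg_test` split, with method names as in recent torchkge releases:

# Link prediction: summary metrics (raw and filtered) after evaluate().
evaluator = LinkPredictionEvaluator(model, kg_test)
evaluator.evaluate(b_size=32, k_max=10)
evaluator.print_results()

# Triplet classification: thresholds are fit on validation, accuracy on test.
clf = TripletClassificationEvaluator(model, kg_val, kg_test)
clf.evaluate(b_size=128)
print(clf.accuracy(b_size=128))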
Example #9
    def test_get_possible_heads_tails(self):
        kg = KnowledgeGraph(self.df)
        h, t = get_possible_heads_tails(kg)

        assert (type(h) == dict) & (type(t) == dict)

        assert h == {0: {0, 2, 5}, 1: {1}, 2: {1}, 3: {3}}
        assert t == {0: {1, 2, 3, 4}, 1: {2}, 2: {3}, 3: {4}}

        p_h, p_t = defaultdict(set), defaultdict(set)
        p_h[0].add(40)
        p_h[10].add(50)
        p_t[0].add(41)
        p_t[10].add(51)

        h, t = get_possible_heads_tails(kg,
                                        possible_heads=dict(p_h),
                                        possible_tails=dict(p_t))

        assert h == {0: {0, 2, 5, 40}, 1: {1}, 2: {1}, 3: {3}, 10: {50}}
        assert t == {0: {1, 2, 3, 4, 41}, 1: {2}, 2: {3}, 3: {4}, 10: {51}}
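In short, `get_possible_heads_tails` collects, for each relation, the sets of entities observed as heads and as tails; when seed dictionaries are passed in, the observed sets are merged into them (here the hypothetical entries 40/41 under relation 0 and 50/51 under the unseen relation 10).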
Example #10
def load_Sweden(data_home=None, GDR=False):
    """Load the Sweden dataset. If pre-split train/valid/test files are found
    on disk they are used directly; otherwise the raw triple file is split
    randomly and the resulting split is saved for later runs.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). The
        Sweden data files are expected to already be present there.
    GDR: bool, optional (default=False)
        If True, entity geographic coordinates are read from `ent2point.txt`
        and attached to the knowledge graphs.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/Sweden'
    if GDR:
        geo = data_path + '/ent2point.txt'
    else:
        geo = None

    if (exists(data_path + '/train.txt') and exists(data_path + '/valid.txt')
            and exists(data_path + '/test.txt')):
        df1 = read_csv(data_path + '/train.txt',
                       sep='\t', header=None, names=['from', 'rel', 'to'])
        df2 = read_csv(data_path + '/valid.txt',
                       sep='\t', header=None, names=['from', 'rel', 'to'])
        df3 = read_csv(data_path + '/test.txt',
                       sep='\t', header=None, names=['from', 'rel', 'to'])
        df = concat([df1, df2, df3])
        kg = KnowledgeGraph(df=df, geo=geo)
        return kg.split_kg(sizes=(len(df1), len(df2), len(df3)), geo=geo)
    else:
        df = read_csv(data_path + '/triplets.txt',
                      sep='\t', header=None, names=['from', 'rel', 'to'],
                      encoding='utf-8')
        kg = KnowledgeGraph(df=df, geo=geo)
        kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True,
                                                geo=geo)
        data_save('/Sweden', kg_train, kg_val, kg_test, geo=geo)
        return kg_train, kg_val, kg_test
Example #11
def load_wikidata_vitals(level=5, data_home=None):
    """Load a knowledge graph extracted from Wikidata, restricted to the
    entities whose Wikipedia pages are contained in Wikivitals. See `here
    <https://netset.telecom-paris.fr/>`__ for details on the Wikivitals and
    Wikivitals+ datasets.

    Parameters
    ----------
    level: int (default=5)
        Either 4 or 5.
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg: torchkge.data_structures.KnowledgeGraph
    kg_attr: torchkge.data_structures.KnowledgeGraph
    """
    assert level in [4, 5]

    if data_home is None:
        data_home = get_data_home()

    data_path = data_home + '/wikidatavitals-level{}'.format(level)

    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/wikidatavitals-level{}.zip"
            .format(level),
            data_home + '/wikidatavitals-level{}.zip'.format(level))

        with zipfile.ZipFile(
                data_home + '/wikidatavitals-level{}.zip'.format(level),
                'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/wikidatavitals-level{}.zip'.format(level))

    df = read_csv(data_path + '/edges.tsv',
                  sep='\t',
                  names=['from', 'to', 'rel'],
                  skiprows=1)
    attributes = read_csv(data_path + '/attributes.tsv',
                          sep='\t',
                          names=['from', 'to', 'rel'],
                          skiprows=1)

    entities = read_csv(data_path + '/entities.tsv', sep='\t')
    relations = read_csv(data_path + '/relations.tsv', sep='\t')
    nodes = read_csv(data_path + '/nodes.tsv', sep='\t')

    df = enrich(df, entities, relations)
    attributes = enrich(attributes, entities, relations)

    relid2label = {
        relations.loc[i, 'wikidataID']: relations.loc[i, 'label']
        for i in relations.index
    }
    entid2label = {
        entities.loc[i, 'wikidataID']: entities.loc[i, 'label']
        for i in entities.index
    }
    entid2pagename = {
        nodes.loc[i, 'wikidataID']: nodes.loc[i, 'pageName']
        for i in nodes.index
    }

    kg = KnowledgeGraph(df)
    kg_attr = KnowledgeGraph(attributes)

    kg.relid2label = relid2label
    kg_attr.relid2label = relid2label
    kg.entid2label = entid2label
    kg_attr.entid2label = entid2label
    kg.entid2pagename = entid2pagename
    kg_attr.entid2pagename = entid2pagename

    return kg, kg_attr
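A brief usage sketch; the mappings attached above turn Wikidata IDs into readable names (the exact entries depend on the downloaded files):

kg, kg_attr = load_wikidata_vitals(level=5)
print(kg.n_ent, kg.n_rel, len(kg))

# Look up the label of one relation through the attached mapping.
rel_id = next(iter(kg.relid2label))
print(rel_id, '->', kg.relid2label[rel_id])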
Example #12
def load_wikidatasets(which, limit_=0, data_home=None):
    """Load one of the WikiDataSets datasets. See `here
    <https://arxiv.org/abs/1906.04536>`__ for the paper by Boschin et al.
    that originally presented the datasets.

    Parameters
    ----------
    which: str
        String indicating which subset of Wikidata should be loaded.
        Available ones are `humans`, `companies`, `animals`, `countries` and
        `films`.
    limit_: int, optional (default=0)
        Lower bound on the number of neighbors an entity must have in the
        graph to be kept.
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg: torchkge.data_structures.KnowledgeGraph

    """
    assert which in ['humans', 'companies', 'animals', 'countries', 'films']

    if data_home is None:
        data_home = get_data_home()

    data_home = data_home + '/WikiDataSets'
    data_path = data_home + '/' + which
    if not exists(data_path):
        print(f"Downloading WikiDataSets/{which}")
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/WikiDataSets/{}.tar.gz".
            format(which), data_home + '/{}.tar.gz'.format(which))

        with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf:
            tf.extractall(data_home)
        remove(data_home + '/{}.tar.gz'.format(which))

    # Replace numeric ids with labels, then rebuild the entity and relation
    # indices from those labels.
    print(
        f"Creating Knowledge Graph Data Structure using WikiDataSets/{which}")
    df = read_csv(data_path + '/edges.tsv',
                  sep='\t',
                  names=['from', 'to', 'rel'],
                  skiprows=[0])
    entities = read_csv(data_path + '/entities.tsv',
                        sep='\t',
                        names=['id', 'wid', 'label'],
                        skiprows=[0])
    relations = read_csv(data_path + '/relations.tsv',
                         sep='\t',
                         names=['id', 'wid', 'label'],
                         skiprows=[0])

    ix2ent = {i: e for i, e in zip(entities['id'], entities['label'])}
    ix2rel = {i: r for i, r in zip(relations['id'], relations['label'])}

    df['from'] = df['from'].map(ix2ent)
    df['to'] = df['to'].map(ix2ent)
    df['rel'] = df['rel'].map(ix2rel)

    entities.drop_duplicates('label', inplace=True)
    relations.drop_duplicates('label', inplace=True)

    ent2ix = {e: i for i, e in enumerate(entities['label'])}
    rel2ix = {r: i for i, r in enumerate(relations['label'])}

    if limit_ > 0:
        a = df.groupby('from').count()['rel']
        b = df.groupby('to').count()['rel']

        # Filter out nodes with too few facts
        tmp = merge(
            right=DataFrame(a).reset_index(),
            left=DataFrame(b).reset_index(),
            how='outer',
            right_on='from',
            left_on='to',
        ).fillna(0)

        tmp['rel'] = tmp['rel_x'] + tmp['rel_y']
        tmp = tmp.drop(['from', 'rel_x', 'rel_y'], axis=1)

        tmp = tmp.loc[tmp['rel'] >= limit_]
        df_bis = df.loc[df['from'].isin(tmp['to']) | df['to'].isin(tmp['to'])]

        kg = KnowledgeGraph(df=df_bis, ent2ix=ent2ix, rel2ix=rel2ix)
    else:
        kg = KnowledgeGraph(df=df, ent2ix=ent2ix, rel2ix=rel2ix)

    return kg
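Usage sketch; the neighbor threshold is an arbitrary illustrative value:

kg = load_wikidatasets('humans', limit_=15)  # keep entities with >= 15 neighbors
print(kg.n_ent, kg.n_rel, len(kg))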
Example #13
def load_fb13(data_home=None):
    """Load the FB13 dataset.

    Parameters
    ----------
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    if data_home is None:
        data_home = get_data_home()
    data_path = data_home + '/FB13'
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/data/torchkge/kgs/FB13.zip",
            data_home + '/FB13.zip')
        with zipfile.ZipFile(data_home + '/FB13.zip', 'r') as zip_ref:
            zip_ref.extractall(data_home)
        remove(data_home + '/FB13.zip')

    df1 = read_csv(data_path + '/train2id.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'],
                   skiprows=[0])
    df2 = read_csv(data_path + '/valid2id.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'],
                   skiprows=[0])
    df3 = read_csv(data_path + '/test2id.txt',
                   sep='\t',
                   header=None,
                   names=['from', 'rel', 'to'],
                   skiprows=[0])
    ent2idx = read_csv(data_path + '/entity2id.txt',
                       sep='\t',
                       header=None,
                       names=['entity', 'idx'],
                       skiprows=[0])
    ent2idx = {e: i for e, i in zip(ent2idx['entity'], ent2idx['idx'])}
    rel2idx = read_csv(data_path + '/relation2id.txt',
                       sep='\t',
                       header=None,
                       names=['relation', 'idx'],
                       skiprows=[0])
    rel2idx = {r: i for r, i in zip(rel2idx['relation'], rel2idx['idx'])}

    df = concat([df1, df2, df3])
    kg = KnowledgeGraph(df=df, ent2ix=ent2idx, rel2ix=rel2idx)

    return kg.split_kg(sizes=(len(df1), len(df2), len(df3)))
Example #14
    def test_get_hpt(self):
        kg = KnowledgeGraph(df=self.df)
        t = cat((kg.head_idx.view(-1, 1),
                 kg.tail_idx.view(-1, 1),
                 kg.relations.view(-1, 1)), dim=1)
        assert get_hpt(t) == {0: 1.5, 1: 1., 2: 1., 3: 1.}
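The 1.5 for relation 0 is the mean number of distinct heads per tail: tails 1, 2 and 3 each have a single head while tail 4 has three ({0, 2, 5}), so (1 + 1 + 1 + 3) / 4 = 1.5; the other three relations appear in one fact each, giving 1.0.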
Example #15
def load_wikidatasets(which, limit_=0, data_home=None):
    """Load one of the WikiDataSets datasets. See `here
    <https://arxiv.org/abs/1906.04536>`__ for the paper by Boschin et al.
    that originally presented the datasets.

    Parameters
    ----------
    which: str
        String indicating which subset of Wikidata should be loaded.
        Available ones are `humans`, `companies`, `animals`, `countries` and
        `films`.
    limit_: int, optional (default=0)
        Lower bound on the number of neighbors an entity must have in the
        graph to be kept.
    data_home: str, optional
        Path to the `torchkge_data` directory (containing data folders). If
        files are not present on disk in this directory, they are downloaded
        and then placed in the right place.

    Returns
    -------
    kg_train: torchkge.data_structures.KnowledgeGraph
    kg_val: torchkge.data_structures.KnowledgeGraph
    kg_test: torchkge.data_structures.KnowledgeGraph

    """
    assert which in ['humans', 'companies', 'animals', 'countries', 'films']

    if data_home is None:
        data_home = get_data_home()

    data_home = data_home + '/WikiDataSets'
    data_path = data_home + '/' + which
    if not exists(data_path):
        makedirs(data_path, exist_ok=True)
        urlretrieve(
            "https://graphs.telecom-paristech.fr/WikiDataSets/{}.tar.gz".
            format(which), data_home + '/{}.tar.gz'.format(which))

        with tarfile.open(data_home + '/{}.tar.gz'.format(which), 'r') as tf:
            tf.extractall(data_home)
        remove(data_home + '/{}.tar.gz'.format(which))

    df = read_csv(data_path + '/edges.txt',
                  sep='\t',
                  header=None,
                  names=['from', 'to', 'rel'],
                  skiprows=[0])

    a = df.groupby('from').count()['rel']
    b = df.groupby('to').count()['rel']

    # Filter out nodes with too few facts
    tmp = merge(
        right=DataFrame(a).reset_index(),
        left=DataFrame(b).reset_index(),
        how='outer',
        right_on='from',
        left_on='to',
    ).fillna(0)

    tmp['rel'] = tmp['rel_x'] + tmp['rel_y']
    tmp = tmp.drop(['from', 'rel_x', 'rel_y'], axis=1)

    tmp = tmp.loc[tmp['rel'] >= limit_]
    df_bis = df.loc[df['from'].isin(tmp['to']) | df['to'].isin(tmp['to'])]

    kg = KnowledgeGraph(df_bis)
    kg_train, kg_val, kg_test = kg.split_kg(share=0.8, validation=True)

    return kg_train, kg_val, kg_test