Ejemplo n.º 1
0
def test_complex_mcl_pipeline():
    """Smoke-test the ComplEx_MCL model end to end.

    Loads the module-level ``triples`` into a fresh ``KgDataset``, fits a
    ``ComplEx_MCL`` model on the ``"default"`` split and scores that same
    split. Nothing is asserted: the test passes as long as neither ``fit``
    nor ``predict`` raises.
    """
    kg = KgDataset()
    kg.load_triples(triples)
    default_split = kg.data["default"]

    model = ComplEx_MCL()
    model.fit(default_split)
    # Scores are produced but intentionally not inspected (smoke test only).
    _scores = model.predict(default_split)
Ejemplo n.º 2
0
def test_distmult_mcl_pipeline():
    """Smoke-test the DistMult_MCL model end to end.

    Loads the module-level ``triples`` into a fresh ``KgDataset``, fits a
    ``DistMult_MCL`` model on the ``"default"`` split and scores that same
    split. Nothing is asserted: the test passes as long as neither ``fit``
    nor ``predict`` raises.
    """
    kg = KgDataset()
    kg.load_triples(triples)
    default_split = kg.data["default"]

    model = DistMult_MCL()
    model.fit(default_split)
    # Scores are produced but intentionally not inspected (smoke test only).
    _scores = model.predict(default_split)
Ejemplo n.º 3
0
def test_trimodel_pipeline():
    """Smoke-test the TriModel model end to end.

    Loads the module-level ``triples`` into a fresh ``KgDataset``, fits a
    ``TriModel`` on the ``"default"`` split and scores that same split.
    Nothing is asserted: the test passes as long as neither ``fit`` nor
    ``predict`` raises.
    """
    kg = KgDataset()
    kg.load_triples(triples)
    default_split = kg.data["default"]

    model = TriModel()
    model.fit(default_split)
    # Scores are produced but intentionally not inspected (smoke test only).
    _scores = model.predict(default_split)
Ejemplo n.º 4
0
def test_load_triplets():
    """Check that ``load_triples`` stores data under the expected key.

    Calling ``load_triples`` with no key must make the data available under
    ``"default"``. Calling it with ``"train"`` must make it available under
    ``"train"``. In both cases the stored array must have the same shape as
    the module-level ``triples`` it was loaded from.
    """
    kg = KgDataset()

    # (key to look up afterwards, extra positional args for load_triples)
    cases = (
        ("default", ()),
        ("train", ("train",)),
    )
    for key, extra_args in cases:
        kg.load_triples(triples, *extra_args)
        assert kg.data[key].shape == triples.shape
Ejemplo n.º 5
0
def test_ent_rel_counts():
    """Check the entity and relation counts reported after loading ``triples``.

    With the module-level ``triples`` fixture loaded, ``get_ents_count`` is
    expected to report 4 and ``get_rels_count`` to report 3.
    """
    kg = KgDataset()
    kg.load_triples(triples)

    # Query both counts before asserting, matching the original ordering.
    n_entities = kg.get_ents_count()
    n_relations = kg.get_rels_count()

    assert n_entities == 4, n_entities
    assert n_relations == 3, n_relations
Ejemplo n.º 6
0
def test_conversion():
    """Round-trip the default split between index and label form.

    Converting the stored ``"default"`` data with ``indices2labels`` must
    reproduce the module-level ``triples``. Converting those labels back
    with ``labels2indices`` must reproduce the stored data. Both comparisons
    are elementwise and reduced with ``.all()``.
    """
    kg = KgDataset()
    kg.load_triples(triples)

    idx = kg.data["default"]
    labels = kg.indices2labels(idx)
    idx_again = kg.labels2indices(labels)

    assert (labels == triples).all()
    assert (idx_again == idx).all()