Ejemplo n.º 1
0
def test_generator_pairwise():
    """Test the batch generator for a pairwise-based algorithm (TransE).

    Every batch yielded by the generator must consist of exactly six columns
    (``ph, pr, pt, nh, nr, nt`` -- positive and negative head/relation/tail),
    and all six columns must have the same length.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config('transe')
    config = config_def(KGEArgParser().get_args([]))
    generator = Generator(model_def(**config.__dict__), config)
    generator.start_one_epoch(10)
    # Always stop the generator, even when an assertion fails, so whatever it
    # started in start_one_epoch (presumably worker processes -- confirm) is
    # not left running for the rest of the test session.
    try:
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 6

            ph, pr, pt, nh, nr, nt = data
            for column in (pr, pt, nh, nr, nt):
                assert len(column) == len(ph)
    finally:
        generator.stop()
Ejemplo n.º 2
0
def test_known_datasets(dataset_name):
    """Test parsing, caching and metadata of the given known dataset.

    Forces data preparation, then checks:
    * ``dump()`` returns 9 entries and the cache exists;
    * every metadata counter read back from the dataset is positive;
    * every cached artifact can be read back and is non-empty.

    Args:
        dataset_name: Dataset identifier passed to ``KnowledgeGraph``;
            presumably supplied by a ``pytest.mark.parametrize`` decorator
            outside this view -- confirm.
    """
    knowledge_graph = KnowledgeGraph(dataset=dataset_name)
    knowledge_graph.force_prepare_data()

    assert len(knowledge_graph.dump()) == 9
    assert knowledge_graph.is_cache_exists()

    kg_metadata = knowledge_graph.dataset.read_metadata()

    assert kg_metadata.tot_triple > 0
    assert kg_metadata.tot_valid_triples > 0
    assert kg_metadata.tot_test_triples > 0
    assert kg_metadata.tot_train_triples > 0
    assert kg_metadata.tot_relation > 0
    assert kg_metadata.tot_entity > 0

    cache_keys = (
        'triplets_train',
        'triplets_test',
        'triplets_valid',
        'hr_t',
        'tr_h',
        'idx2entity',
        'idx2relation',
        'entity2idx',
        'relation2idx',
    )
    for cache_key in cache_keys:
        # The message pinpoints which artifact is empty for which dataset.
        assert len(knowledge_graph.read_cache_data(cache_key)) > 0, (
            f"cached {cache_key!r} is empty for dataset {dataset_name!r}"
        )
Ejemplo n.º 3
0
def test_fb15k_meta():
    """Smoke-test FB15k preparation, caching and the dataset's helpers.

    Forces a fresh preparation and dumps the graph, checks that the cache
    was written, calls ``prepare_data`` again (presumably reusing the cache),
    and finally exercises the dataset's ``read_metadata`` and ``dump``.
    Return values are not inspected beyond the cache check.
    """
    graph = KnowledgeGraph(dataset="freebase15k")
    graph.force_prepare_data()
    graph.dump()

    cache_written = graph.is_cache_exists()
    assert cache_written
    graph.prepare_data()

    # Call each dataset helper in turn, in the same order as before.
    for helper_name in ("read_metadata", "dump"):
        getattr(graph.dataset, helper_name)()
Ejemplo n.º 4
0
def test_fb15k_manipulate():
    """Smoke-test reading back every cached FB15k artifact.

    After a forced preparation and a dump, each artifact is read from the
    cache. The values are not inspected; the test passes as long as every
    ``read_cache_data`` call completes without raising.
    """
    graph = KnowledgeGraph(dataset="freebase15k")
    graph.force_prepare_data()
    graph.dump()

    cached_artifacts = (
        'triplets_train', 'triplets_test', 'triplets_valid',
        'hr_t', 'tr_h',
        'idx2entity', 'idx2relation',
        'entity2idx', 'relation2idx',
    )
    for artifact in cached_artifacts:
        graph.read_cache_data(artifact)
Ejemplo n.º 5
0
def test_generator_pointwise():
    """Test the batch generator for a pointwise-based algorithm (ComplEx).

    Every batch must consist of exactly four columns ``h, r, t, y``. The
    ``h``, ``r`` and ``t`` columns must have equal length, and the labels
    ``y`` must contain both 1 and -1 and no other value.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config("complex")
    config = config_def(KGEArgParser().get_args([]))
    generator = Generator(model_def(**config.__dict__), config)
    generator.start_one_epoch(10)
    # Always stop the generator, even when an assertion fails, so whatever it
    # started in start_one_epoch (presumably worker processes -- confirm) is
    # not left running for the rest of the test session.
    try:
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 4

            h, r, t, y = data
            assert len(h) == len(r)
            assert len(h) == len(t)
            # Each batch must hold both positive (1) and negative (-1) labels.
            assert set(y) == {1, -1}
    finally:
        generator.stop()
Ejemplo n.º 6
0
def test_generator_proje():
    """Test the batch generator for a projection-based algorithm (ProjE).

    Every batch must consist of exactly five items ``h, r, t, hr_t, tr_h``.
    The ``h``, ``r`` and ``t`` columns must have equal length, and ``hr_t``
    and ``tr_h`` must be ``torch.Tensor`` instances.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config("proje_pointwise")
    config = config_def(KGEArgParser().get_args([]))
    generator = Generator(model_def(**config.__dict__), config)
    generator.start_one_epoch(10)
    # Always stop the generator, even when an assertion fails, so whatever it
    # started in start_one_epoch (presumably worker processes -- confirm) is
    # not left running for the rest of the test session.
    try:
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 5

            h, r, t, hr_t, tr_h = data
            assert len(h) == len(r)
            assert len(h) == len(t)
            assert isinstance(hr_t, torch.Tensor)
            assert isinstance(tr_h, torch.Tensor)
    finally:
        generator.stop()
Ejemplo n.º 7
0
def test_benchmarks(dataset_name):
    """Test that the given benchmark dataset can be parsed.

    Forces data preparation for the dataset and then dumps the graph.

    Args:
        dataset_name: Dataset identifier passed to ``KnowledgeGraph``;
            presumably supplied by a ``pytest.mark.parametrize`` decorator
            outside this view -- confirm.
    """
    knowledge_graph = KnowledgeGraph(dataset=dataset_name)
    knowledge_graph.force_prepare_data()
    knowledge_graph.dump()