def test_generator_pairwise():
    """Test the batch generator for a pairwise-based algorithm (TransE).

    Runs one epoch of 10 batches and checks that each batch unpacks into six
    equal-length sequences (named ph, pr, pt, nh, nr, nt -- presumably the
    positive and negative head/relation/tail; confirm against Generator).
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config('transe')
    config = config_def(KGEArgParser().get_args([]))

    generator = Generator(model_def(**config.__dict__), config)
    # stop() lives in `finally` so a failing assertion (or a raising next())
    # does not leave the generator running after the test ends.
    try:
        generator.start_one_epoch(10)
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 6
            ph, pr, pt, nh, nr, nt = data
            assert len(ph) == len(pr) == len(pt) == len(nh) == len(nr) == len(nt)
    finally:
        generator.stop()
def test_known_datasets(dataset_name):
    """Test that the knowledge graph for ``dataset_name`` is parsed and cached.

    After a forced data preparation, checks that:
      * ``dump()`` returns 9 items,
      * the cache exists,
      * every total count in the dataset metadata is positive,
      * every cached table (triplet splits and lookup maps) is non-empty.
    """
    knowledge_graph = KnowledgeGraph(dataset=dataset_name)
    knowledge_graph.force_prepare_data()

    assert len(knowledge_graph.dump()) == 9
    assert knowledge_graph.is_cache_exists()

    kg_metadata = knowledge_graph.dataset.read_metadata()
    assert kg_metadata.tot_triple > 0
    assert kg_metadata.tot_valid_triples > 0
    assert kg_metadata.tot_test_triples > 0
    assert kg_metadata.tot_train_triples > 0
    assert kg_metadata.tot_relation > 0
    assert kg_metadata.tot_entity > 0

    # Same keys and order as the individual checks this loop replaces; the
    # message names the offending cache entry when one is empty.
    cache_keys = (
        'triplets_train',
        'triplets_test',
        'triplets_valid',
        'hr_t',
        'tr_h',
        'idx2entity',
        'idx2relation',
        'entity2idx',
        'relation2idx',
    )
    for key in cache_keys:
        assert len(knowledge_graph.read_cache_data(key)) > 0, \
            f"cached '{key}' is empty for dataset {dataset_name!r}"
def test_fb15k_meta():
    """Test preparing, dumping and re-preparing the Freebase15k knowledge graph.

    Forces data preparation, dumps the graph and checks that the cache was
    created. It then calls ``prepare_data()`` again with the cache already
    present, and exercises the dataset's metadata read and dump. Apart from
    the cache check, the test passes as long as none of these calls raise.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()
    knowledge_graph.dump()

    assert knowledge_graph.is_cache_exists()

    # Second preparation runs with the cache already in place.
    knowledge_graph.prepare_data()
    knowledge_graph.dataset.read_metadata()
    knowledge_graph.dataset.dump()
def test_fb15k_manipulate():
    """Read back every cached table of the prepared Freebase15k knowledge graph.

    Forces data preparation and dumps the graph, then loads each cache entry
    in turn. The test passes as long as none of the reads raise.
    """
    kg = KnowledgeGraph(dataset="freebase15k")
    kg.force_prepare_data()
    kg.dump()

    cache_names = (
        'triplets_train', 'triplets_test', 'triplets_valid',
        'hr_t', 'tr_h',
        'idx2entity', 'idx2relation',
        'entity2idx', 'relation2idx',
    )
    for cache_name in cache_names:
        kg.read_cache_data(cache_name)
def test_generator_pointwise():
    """Test the batch generator for a pointwise-based algorithm (ComplEx).

    Runs one epoch of 10 batches. Each batch must unpack into four items
    (h, r, t, y). ``h``, ``r`` and ``t`` must have equal lengths, and the
    distinct values in ``y`` must be exactly {1, -1}.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config("complex")
    config = config_def(KGEArgParser().get_args([]))

    generator = Generator(model_def(**config.__dict__), config)
    # stop() lives in `finally` so a failing assertion (or a raising next())
    # does not leave the generator running after the test ends.
    try:
        generator.start_one_epoch(10)
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 4
            h, r, t, y = data
            assert len(h) == len(r) == len(t)
            # Every batch is required to contain both label values.
            assert set(y) == {1, -1}
    finally:
        generator.stop()
def test_generator_proje():
    """Test the batch generator for a projection-based algorithm (ProjE pointwise).

    Runs one epoch of 10 batches. Each batch must unpack into five items
    (h, r, t, hr_t, tr_h). ``h``, ``r`` and ``t`` must have equal lengths,
    and ``hr_t`` and ``tr_h`` must be ``torch.Tensor`` instances.
    """
    knowledge_graph = KnowledgeGraph(dataset="freebase15k")
    knowledge_graph.force_prepare_data()

    config_def, model_def = Importer().import_model_config("proje_pointwise")
    config = config_def(KGEArgParser().get_args([]))

    generator = Generator(model_def(**config.__dict__), config)
    # stop() lives in `finally` so a failing assertion (or a raising next())
    # does not leave the generator running after the test ends.
    try:
        generator.start_one_epoch(10)
        for _ in range(10):
            data = list(next(generator))
            assert len(data) == 5
            h, r, t, hr_t, tr_h = data
            assert len(h) == len(r) == len(t)
            assert isinstance(hr_t, torch.Tensor)
            assert isinstance(tr_h, torch.Tensor)
    finally:
        generator.stop()
def test_benchmarks(dataset_name):
    """Test that the knowledge graph for ``dataset_name`` can be prepared and dumped.

    This is a smoke test: it passes as long as forced data preparation and
    ``dump()`` do not raise.
    """
    knowledge_graph = KnowledgeGraph(dataset=dataset_name)
    knowledge_graph.force_prepare_data()
    knowledge_graph.dump()