def main():
    """Entry point: load a knowledge graph from a user path (or fall back to
    the default Freebase15k dataset) and train the configured models."""
    args = KGEArgParser().get_args(sys.argv[1:])

    if Path(args.dataset_path).exists():
        # User-supplied dataset directory found — load triples from it.
        loader = KnowledgeDataLoader(data_dir=args.dataset_path,
                                     negative_sampling=args.sampling)
        kg = loader.get_knowledge_graph()
        print('Successfully loaded {} triples from {}.'.format(
            len(loader.triples), loader.data_dir))
    else:
        # Fall back to the bundled default dataset.
        print('Unable to find dataset from path:', args.dataset_path)
        print(
            'Default loading Freebase15k dataset with default hyperparameters...'
        )
        kg = KnowledgeGraph()

    kg.prepare_data()
    kg.dump()

    # TODO: Not sure why new dataset isn't cached on subsequent hits...
    args.dataset_path = './data/' + kg.dataset_name
    args.dataset_name = kg.dataset_name

    # Add new model configurations to run.
    models = [TransE(transe_config(args=args))]

    for model in models:
        print('---- Training Model: {} ----'.format(model.model_name))
        trainer = Trainer(model=model, debug=args.debug)
        trainer.build_model()
        trainer.train_model()
        # Clear the TF1 default graph so the next model starts fresh.
        tf.reset_default_graph()
def test_userdefined_dataset():
    """Prepare a tiny user-defined dataset, verify every cache artifact is
    readable, and check the metadata counts match the bundled resource."""
    custom_dataset_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "resource/custom_dataset")
    kg = KnowledgeGraph(dataset="userdefineddataset",
                        negative_sample="uniform",
                        custom_dataset_path=custom_dataset_path)
    kg.prepare_data()
    kg.dump()

    # Every cache artifact written by prepare_data() must round-trip.
    for cache_key in ('triplets_train', 'triplets_test', 'triplets_valid',
                      'hr_t', 'tr_h', 'idx2entity', 'idx2relation',
                      'entity2idx', 'relation2idx'):
        kg.read_cache_data(cache_key)

    kg.dataset.read_metadata()
    kg.dataset.dump()

    # The bundled custom dataset holds exactly one triple per split.
    assert kg.kg_meta.tot_train_triples == 1
    assert kg.kg_meta.tot_test_triples == 1
    assert kg.kg_meta.tot_valid_triples == 1
    assert kg.kg_meta.tot_entity == 6
    assert kg.kg_meta.tot_relation == 3
def test_fb15k_meta():
    """Check Freebase15k preparation, cache existence, and metadata dumping."""
    kg = KnowledgeGraph(dataset="freebase15k")
    kg.force_prepare_data()
    kg.dump()
    # After a forced preparation the cache must exist on disk.
    assert kg.is_cache_exists()
    kg.prepare_data()
    kg.dataset.read_metadata()
    kg.dataset.dump()
def test_fb15k_manipulate():
    """Verify all Freebase15k cache artifacts can be read back after preparation."""
    kg = KnowledgeGraph(dataset="freebase15k")
    kg.force_prepare_data()
    kg.dump()
    # Each named cache entry must load without error.
    for cache_key in ('triplets_train', 'triplets_test', 'triplets_valid',
                      'hr_t', 'tr_h', 'idx2entity', 'idx2relation',
                      'entity2idx', 'relation2idx'):
        kg.read_cache_data(cache_key)
def test_dl50a():
    """Smoke-test parsing of the deeplearning50a knowledge base."""
    kg = KnowledgeGraph(dataset="deeplearning50a", negative_sample="uniform")
    kg.force_prepare_data()
    kg.dump()
def test_fb15k():
    """Smoke-test parsing of the Freebase15k dataset."""
    kg = KnowledgeGraph(dataset="freebase15k", negative_sample="uniform")
    kg.force_prepare_data()
    kg.dump()
def test_yago():
    """Smoke-test parsing of the YAGO3-10 dataset."""
    kg = KnowledgeGraph(dataset="yago3_10", negative_sample="uniform")
    kg.force_prepare_data()
    kg.dump()
def test_wn18rr():
    """Smoke-test parsing of the WordNet18-RR dataset."""
    kg = KnowledgeGraph(dataset="wordnet18_rr", negative_sample="uniform")
    kg.force_prepare_data()
    kg.dump()
def test_benchmarks(dataset_name):
    """Download the named benchmark dataset and smoke-test the KG controller.

    Args:
        dataset_name: identifier of a benchmark dataset to fetch and parse.
    """
    print("testing downloading from online sources of benchmarks and KG controller's handling.")
    kg = KnowledgeGraph(dataset=dataset_name)
    kg.force_prepare_data()
    kg.dump()