def run():
    """Load the dataset, build the configured KGE model and train it.

    Reads ``entities.dict`` / ``relations.dict``, the train/valid/test
    splits and the symmetry / inversion / composition / other test splits
    from ``config.data_path``. It then instantiates the model named by
    ``config.model``, falling back to TransE (with a warning) when the name
    is not recognised. Finally it moves the model to ``cuda:0`` and passes
    everything to ``train``.

    NOTE(review): another ``def run()`` appears later in this chunk; if both
    live in the same module only the later definition is bound -- confirm
    which one is intended.
    """
    def data_file(name):
        # Every dataset file lives directly under config.data_path.
        return os.path.join(config.data_path, name)

    # Load data.
    ent2id = read_elements(data_file("entities.dict"))
    rel2id = read_elements(data_file("relations.dict"))
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(data_file("train.txt"), ent2id, rel2id)
    valid_triples = read_triples(data_file("valid.txt"), ent2id, rel2id)
    test_triples = read_triples(data_file("test.txt"), ent2id, rel2id)
    # Extra test splits grouped by relation pattern.
    symmetry_test = read_triples(data_file("symmetry_test.txt"), ent2id, rel2id)
    inversion_test = read_triples(data_file("inversion_test.txt"), ent2id, rel2id)
    composition_test = read_triples(data_file("composition_test.txt"), ent2id, rel2id)
    others_test = read_triples(data_file("other_test.txt"), ent2id, rel2id)

    logging.info("#ent_num: %d", ent_num)
    logging.info("#rel_num: %d", rel_num)
    logging.info("#train triple num: %d", len(train_triples))
    logging.info("#valid triple num: %d", len(valid_triples))
    logging.info("#test triple num: %d", len(test_triples))
    logging.info("#symmetry test triple num: %d", len(symmetry_test))
    logging.info("#inversion test triple num: %d", len(inversion_test))
    logging.info("#composition test triple num: %d", len(composition_test))
    logging.info("#other test triple num: %d", len(others_test))

    model_name = config.model
    logging.info("#Model: %s", model_name)

    # Create the model. Pick the class first and instantiate it exactly once,
    # so no throwaway TransE is allocated when another model is requested.
    # Class names are looked up only in the branch taken, as before.
    if model_name == "TransH":
        model_cls = TransH
    elif model_name == "TransR":
        model_cls = SimpleTransR
    elif model_name == "TransD":
        model_cls = TransD
    elif model_name == "STransE":
        model_cls = STransE
    elif model_name == "LineaRE":
        model_cls = LineaRE
    elif model_name == "DistMult":
        model_cls = DistMult
    elif model_name == "ComplEx":
        model_cls = ComplEx
    elif model_name == "RotatE":
        model_cls = RotatE
    elif model_name == "TransIJ":
        model_cls = TransIJ
    else:
        # Keep the historical TransE fallback, but make it visible instead of
        # silently training a different model than the one requested.
        if model_name != "TransE":
            logging.warning("Unknown model %s, falling back to TransE", model_name)
        model_cls = TransE
    kge_model = model_cls(ent_num, rel_num)
    kge_model = kge_model.cuda(torch.device("cuda:0"))

    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s"
                     % (name, str(param.size()), str(param.requires_grad)))

    # Train.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples,
                   symmetry_test, inversion_test, composition_test, others_test),
          ent_num=ent_num)
def run():
    """Set up logging, load the dataset, build the configured model and train it.

    Reads ``entities.dict`` / ``relations.dict`` and the train/valid/test
    splits from ``config.data_path``. It then instantiates the model named by
    ``config.model``, falling back to TransE (with a warning) when the name
    is not recognised. The model is moved to the GPU when ``config.cuda`` is
    truthy, and the result is passed to ``train``.
    """
    set_logger()

    def data_file(name):
        # Every dataset file lives directly under config.data_path.
        return os.path.join(config.data_path, name)

    # Load data.
    ent2id = read_elements(data_file("entities.dict"))
    rel2id = read_elements(data_file("relations.dict"))
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(data_file("train.txt"), ent2id, rel2id)
    valid_triples = read_triples(data_file("valid.txt"), ent2id, rel2id)
    test_triples = read_triples(data_file("test.txt"), ent2id, rel2id)

    logging.info("#ent_num: %d", ent_num)
    logging.info("#rel_num: %d", rel_num)
    logging.info("#train triple num: %d", len(train_triples))
    logging.info("#valid triple num: %d", len(valid_triples))
    logging.info("#test triple num: %d", len(test_triples))

    model_name = config.model
    logging.info("#Model: %s", model_name)

    # Create the model. Pick the class first and instantiate it exactly once,
    # so no throwaway TransE is allocated when another model is requested.
    # Class names are looked up only in the branch taken, as before.
    if model_name == "TransH":
        model_cls = TransH
    elif model_name == "TransR":
        model_cls = TransR
    elif model_name == "TransD":
        model_cls = TransD
    elif model_name == "STransE":
        model_cls = STransE
    elif model_name == "LineaRE":
        model_cls = LineaRE
    elif model_name == "DistMult":
        model_cls = DistMult
    elif model_name == "ComplEx":
        model_cls = ComplEx
    elif model_name == "RotatE":
        model_cls = RotatE
    else:
        # Keep the historical TransE fallback, but make it visible instead of
        # silently training a different model than the one requested.
        if model_name != "TransE":
            logging.warning("Unknown model %s, falling back to TransE", model_name)
        model_cls = TransE
    kge_model = model_cls(ent_num, rel_num)
    if config.cuda:
        kge_model = kge_model.cuda()

    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s"
                     % (name, str(param.size()), str(param.requires_grad)))

    # Train.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples),
          ent_num=ent_num)