def main():
    """Parse CLI options, build the knowledge graph and TransE model, train.

    Training itself (including evaluation scheduling) is delegated to
    TransE.train(); this entry point only wires up configuration.
    """
    arg_parser = argparse.ArgumentParser(description='TransE')
    # (flag, type, default) triples for every supported option.
    options = [
        ('--data_dir', str, '../data/FB15k/'),
        ('--embedding_dim', int, 200),
        ('--margin_value', float, 1.0),
        ('--score_func', str, 'L1'),
        ('--batch_size', int, 4800),
        ('--learning_rate', float, 0.001),
        ('--n_generator', int, 24),
        ('--n_rank_calculator', int, 24),
        ('--ckpt_dir', str, '../ckpt/'),
        ('--summary_dir', str, '../summary/'),
        ('--max_epoch', int, 500),
        ('--eval_freq', int, 10),
    ]
    for flag, flag_type, default in options:
        arg_parser.add_argument(flag, type=flag_type, default=default)
    args = arg_parser.parse_args()
    print(args)
    knowledge_graph = KnowledgeGraph(data_dir=args.data_dir)
    model = TransE(kg=knowledge_graph,
                   embedding_dim=args.embedding_dim,
                   margin_value=args.margin_value,
                   score_func=args.score_func,
                   batch_size=args.batch_size,
                   learning_rate=args.learning_rate,
                   n_generator=args.n_generator,
                   n_rank_calculator=args.n_rank_calculator,
                   max_epoch=args.max_epoch,
                   eval_freq=args.eval_freq)
    # Grow GPU memory on demand instead of grabbing it all up front.
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        model.train(sess)
def main():
    """Train TransE inside one TF session, periodically evaluating and
    saving the learned embeddings.

    Generalization: the embedding-save period used to be hard-coded to 10
    epochs; it is now the --save_freq option (default 10, so behavior is
    unchanged for existing invocations).
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='./data/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=5000)
    parser.add_argument('--learning_rate', type=float, default=0.003)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    # Huge default effectively disables periodic evaluation.
    parser.add_argument('--eval_freq', type=int, default=10000000)
    # New, backward-compatible: how often (in epochs) to save embeddings.
    parser.add_argument('--save_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg, embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                               graph=sess.graph)
        for epoch in range(args.max_epoch):
            print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
            kge_model.launch_training(session=sess,
                                      summary_writer=summary_writer)
            if (epoch + 1) % args.eval_freq == 0:
                kge_model.launch_evaluation(session=sess)
            if (epoch + 1) % args.save_freq == 0:
                kge_model.save_embedding(session=sess)
def main():
    """Train TransE, evaluate every --eval_freq epochs, and checkpoint
    after every epoch.

    Each checkpoint is written twice: as a tf.train.Saver checkpoint and
    as a SavedModel export exposing the entity/relation embedding tensors.

    Fix: --model_name previously had no default, so omitting it produced
    args.model_name == None and a TypeError when the checkpoint path was
    concatenated; it is now required so argparse rejects the omission
    up-front with a clear error.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg, embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       model_name=args.model_name,
                       ckpt_dir=args.ckpt_dir)
    gpu_config = tf.GPUOptions(allow_growth=False)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                               graph=sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=500)
        for epoch in range(args.max_epoch):
            print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
            kge_model.launch_training(session=sess,
                                      summary_writer=summary_writer)
            if (epoch + 1) % args.eval_freq == 0:
                kge_model.launch_evaluation(session=sess, saver=saver)
            print('-----Save checkpoint-----')
            step_str = str(kge_model.global_step.eval(session=sess))
            save_path = args.ckpt_dir + '/' + args.model_name + step_str + '.ckpt'
            saver_path = saver.save(sess, save_path)
            # Also export a SavedModel so the embeddings can be loaded
            # later without rebuilding the training graph.
            tf.saved_model.simple_save(
                sess, args.ckpt_dir + '/model-' + step_str,
                inputs={'triple': kge_model.eval_triple},
                outputs={
                    'entity-embedding': kge_model.entity_embedding,
                    'relation-embedding': kge_model.relation_embedding
                })
            print("Model saved in path: %s" % saver_path)
def transe(data_path, embedding_dims, margin_value, score_func, batch_size,
           learning_rate, n_generator, n_rank_calculator, max_epoch):
    """Build and train a TransE model on the graph found at *data_path*.

    Returns the (entity_embedding, relation_embedding) pair reported by
    check_norm right after variable initialization.
    NOTE(review): the returned arrays are captured *before* the training
    loop runs — confirm whether post-training embeddings were intended.
    """
    graph = KnowledgeGraph(data_dir=data_path)
    model = TransE(kg=graph, embedding_dim=embedding_dims,
                   margin_value=margin_value, score_func=score_func,
                   batch_size=batch_size, learning_rate=learning_rate,
                   n_generator=n_generator,
                   n_rank_calculator=n_rank_calculator)
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        entity_embedding, relation_embedding = model.check_norm(session=sess)
        writer = tf.summary.FileWriter(logdir='../summary/',
                                       graph=sess.graph)
        for epoch in range(max_epoch):
            print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
            model.launch_training(session=sess, summary_writer=writer)
            # Evaluate every tenth epoch.
            if (epoch + 1) % 10 == 0:
                model.launch_evaluation(session=sess)
    return entity_embedding, relation_embedding
def main():
    """Restore a trained TransE checkpoint and run evaluation on it.

    Fixes over the original:
    * tf.reset_default_graph() was called AFTER the TransE graph had been
      built, wiping the default graph so the initializer and Saver saw no
      variables; it now runs before any graph construction.
    * Saver.restore() takes the checkpoint *prefix* (e.g. 'xxx.ckpt');
      the original passed the physical 'xxx.ckpt.index' file, which fails.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    # Start from a clean graph BEFORE building the model.
    tf.reset_default_graph()
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg, embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       model_name=args.model_name,
                       ckpt_dir=args.ckpt_dir)
    gpu_config = tf.GPUOptions(allow_growth=False)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        saver = tf.train.Saver()
        # Checkpoint prefix, NOT the '.index' file.
        saver.restore(sess, '../checkpoints/gov-g2/GOV-g2130000.ckpt')
        kge_model.launch_evaluation(session=sess, saver=saver)
        print('-----Model Loaded-----')
def transe(id, data_path, embedding_dims, margin_value, score_func,
           batch_size, learning_rate, n_generator, n_rank_calculator,
           max_epoch, d):
    """Train one TransE instance (worker *id*) and publish its results.

    Stores [final_loss, entity_embedding, relation_embedding] into the
    shared mapping *d* under key *id*, and returns the two embedding
    arrays obtained from check_norm after training.
    """
    knowledge_graph = KnowledgeGraph(data_dir=data_path)
    model = TransE(kg=knowledge_graph, embedding_dim=embedding_dims,
                   margin_value=margin_value, score_func=score_func,
                   batch_size=batch_size, learning_rate=learning_rate,
                   n_generator=n_generator,
                   n_rank_calculator=n_rank_calculator, id=id)
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        tf.global_variables_initializer().run()
        writer = tf.summary.FileWriter(logdir='../summary/',
                                       graph=sess.graph)
        for epoch in range(max_epoch):
            model.launch_training(session=sess, summary_writer=writer)
            # Periodic evaluation every 50 epochs.
            if (epoch + 1) % 50 == 0:
                model.launch_evaluation(session=sess)
        loss, entity_embedding, relation_embedding = model.check_norm(
            session=sess)
        # Publish results for the coordinating process.
        d[id] = [loss, entity_embedding, relation_embedding]
        print('FB15k-{} loss:{}'.format(id, d[id][0]))
    return entity_embedding, relation_embedding
def main():
    """Train TransE on the 'after_big' dataset, mirroring progress to a
    log file and checkpointing after every epoch.

    Improvement: the original opened/closed the log file by hand before
    and after every write (leak-prone if a write raised); all log writes
    now go through one context-managed helper with identical output.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/after_big/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--eval_batch_size', type=int, default=200)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=1000)
    parser.add_argument('--eval_freq', type=int, default=200)
    parser.add_argument('--log_file', type=str,
                        default='../log/log_after_big.txt')
    args = parser.parse_args()
    print(args)

    def log(message):
        # Append one CRLF-terminated line; the context manager guarantees
        # the handle is closed even if the write fails.
        with open(args.log_file, 'a') as fp:
            fp.write(message + '\r\n')

    # Truncate any previous log.
    with open(args.log_file, 'w'):
        pass
    fb15k = Dataset(data_dir=args.data_dir, log_file=args.log_file)
    kge_model = TransE(dataset=fb15k, embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       eval_batch_size=args.eval_batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       log_file=args.log_file)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        log('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        print('----Check norm----')
        log('-----Initialization accomplished-----')
        log('----Check norm----')
        # Sanity-check the embedding L2 norms right after initialization.
        entity_embedding = kge_model.entity_embedding.eval(session=sess)
        relation_embedding = kge_model.relation_embedding.eval(session=sess)
        entity_norm = np.linalg.norm(entity_embedding, ord=2, axis=1)
        relation_norm = np.linalg.norm(relation_embedding, ord=2, axis=1)
        print('entity norm: {} relation norm: {}'.format(
            entity_norm, relation_norm))
        log('entity norm: {} relation norm: {}'.format(
            entity_norm, relation_norm))
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                               graph=sess.graph)
        # Keep only the newest checkpoint to bound disk usage.
        saver = tf.train.Saver(max_to_keep=1)
        for epoch in range(args.max_epoch):
            banner = '=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30
            print(banner)
            log(banner)
            kge_model.launch_training(session=sess,
                                      summary_writer=summary_writer)
            saver.save(sess, '../ckpt/after_big.ckpt', global_step=epoch + 1)
            if (epoch + 1) % args.eval_freq == 0:
                kge_model.launch_evaluation(session=sess)
def main(): parser = argparse.ArgumentParser(description='TransE') parser.add_argument('--data_dir', type=str, default='../data/') parser.add_argument('--embedding_dim', type=int, default=200) parser.add_argument('--margin_value', type=float, default=1.0) parser.add_argument('--score_func', type=str, default='L1') parser.add_argument('--batch_size', type=int, default=4800) parser.add_argument('--learning_rate', type=float, default=0.001) parser.add_argument('--n_generator', type=int, default=24) parser.add_argument('--n_rank_calculator', type=int, default=24) parser.add_argument('--ckpt_dir', type=str, default='../ckpt/') parser.add_argument('--summary_dir', type=str, default='../summary/') parser.add_argument('--max_epoch', type=int, default=500) parser.add_argument('--eval_freq', type=int, default=10) args = parser.parse_args() print(args) ''' 传递参数 Namespace( batch_size=4800, ckpt_dir='../ckpt/', data_dir='../data/FB15k/', embedding_dim=200, eval_freq=10, learning_rate=0.001, margin_value=1.0, max_epoch=500, n_generator=24, n_rank_calculator=24, score_func='L1', //summary_dir='../summary/') ''' kg = KnowledgeGraph(data_dir=args.data_dir)#create graph kge_model = TransE( kg=kg, embedding_dim=args.embedding_dim, margin_value=args.margin_value, score_func=args.score_func, batch_size=args.batch_size, learning_rate=args.learning_rate, n_generator=args.n_generator, n_rank_calculator=args.n_rank_calculator )#init embd... gpu_config = tf.GPUOptions(allow_growth=True) ''' tf.GPUOptions:可以作为设置tf.ConfigProto时的一个参数选项,一般用于限制GPU资源的使用 allow_growth=True:动态申请现显存 ''' sess_config = tf.ConfigProto(gpu_options=gpu_config) ''' tf.ConfigProto: 创建session的时候,用来对session进行参数配置 ''' with tf.Session(config = sess_config) as sess: ''' Session 是 Tensorflow 为了控制,和输出文件的执行的语句. 
运行 session.run() 可以获得你要得知的运算结果 ''' print('-----Initializing tf graph-----') tf.global_variables_initializer().run()#就是 run了 所有global Variable 的 assign op,这就是初始化参数的本来面目。 print('-----Initialization accomplished-----') kge_model.check_norm(session=sess) summary_writer = tf.summary.FileWriter(logdir=args.summary_dir, graph=sess.graph) # print(type(sess.graph)) # print(type(summary_writer)) '''
link_dic = {} entity_idx = 0 link_idx = 0 for row in train_tsv_reader: head = row[0] link = row[1] tail = row[2] if not head in entity_dic: entity_dic[head] = entity_idx entity_idx += 1 if not tail in entity_dic: entity_dic[tail] = entity_idx entity_idx += 1 if not link in link_dic: link_dic[link] = link_idx link_idx += 1 data.append((entity_dic[head], link_dic[link], entity_dic[tail])) return data, entity_dic, link_dic train_data, entity_dic, link_dic = load(TRAIN_DATASET_PATH) #training model = TransE(len(entity_dic), len(link_dic), 1, 50, 0.01, 50) model.fit(np.array(train_data)) with open('models/transe.pkl', 'wb') as f: pickle.dump(model, f)
def prepareModel(self):
    """Instantiate the TransE model (100-dim embeddings) and move it to
    the GPU when one is configured."""
    print("Perpare model")
    model = TransE(self.n_entities, self.n_relations, embDim=100)
    if GPU:
        # nn.Module.cuda() moves parameters in place.
        model.cuda()
    self.model = model
class Train:
    """Training harness: prepares data/model, runs the seeded training
    loop, and persists the best embedding weights.

    NOTE(review): this block was reconstructed from a single collapsed
    line; the loop nesting below is the most plausible reading — verify
    against the original formatting.
    """

    def __init__(self, data_name):
        # Dataset lookup by name; entity/relation counts drive model size.
        self.dataset = get_dataset(data_name)
        self.n_entities = self.dataset.n_entities
        self.n_relations = self.dataset.n_relations

    def prepareData(self):
        """Build train/valid datasets; the train loader is rebuilt per
        seed in fit(), so it starts as None here."""
        print("Perpare dataloader")
        self.train = TrainDataset(self.dataset)
        self.trainloader = None
        self.valid = EvalDataset(self.dataset)
        # One full-batch loader: batch_size equals the whole valid split.
        self.validloader = DataLoader(self.valid,
                                      batch_size=self.valid.n_triples,
                                      shuffle=False)

    def prepareModel(self):
        """Instantiate TransE (100-dim embeddings); move to GPU if enabled."""
        print("Perpare model")
        self.model = TransE(self.n_entities, self.n_relations, embDim=100)
        if GPU:
            self.model.cuda()

    def saveModel(self):
        """Persist the current embedding weights to emb_weight.pkl."""
        pickle.dump(self.model.get_emb_weights(),
                    open('emb_weight.pkl', 'wb'))

    def fit(self):
        """Run training: for each of 100 seeds, regenerate negative
        samples and train EPOCHS_PER_SEED epochs; track min loss and the
        best mean rank (MR), saving weights whenever MR improves."""
        optim = torch.optim.Adam(self.model.parameters(), lr=LR)
        minLoss = float("inf")
        bestMR = float("inf")
        GlobalEpoch = 0  # epochs across all seeds, drives LR decay
        for seed in range(100):
            print(f"# Using seed: {seed}")
            # Fresh negative samples per seed.
            self.train.regenerate_neg_samples(seed=seed)
            self.trainloader = DataLoader(self.train, batch_size=1024,
                                          shuffle=True, num_workers=4)
            for epoch in range(EPOCHS_PER_SEED):
                GlobalEpoch += 1
                for sample in self.trainloader:
                    if GPU:
                        pos_triples = torch.LongTensor(
                            sample['pos_triples']).cuda()
                        neg_triples = torch.LongTensor(
                            sample['neg_triples']).cuda()
                    else:
                        pos_triples = torch.LongTensor(sample['pos_triples'])
                        neg_triples = torch.LongTensor(sample['neg_triples'])
                    # Re-normalize embeddings before each forward pass.
                    self.model.normal_emb()
                    loss = self.model(pos_triples, neg_triples)
                    if GPU:
                        lossVal = loss.cpu().item()
                    else:
                        lossVal = loss.item()
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    if minLoss > lossVal:
                        minLoss = lossVal
                # Evaluate mean rank on the validation split.
                MR = Eval_MR(self.validloader, "L2",
                             **self.model.get_emb_weights())
                if MR < bestMR:
                    bestMR = MR
                    print('save embedding weight')
                    self.saveModel()
                print(
                    f"Epoch: {epoch + 1}, Total_Train: {GlobalEpoch}, Loss: {lossVal}, minLoss: {minLoss},"
                    f"MR: {MR}, bestMR: {bestMR}")
                # Multiplicative LR decay every LR_DECAY_EPOCH epochs.
                if GlobalEpoch % LR_DECAY_EPOCH == 0:
                    adjust_learning_rate(optim, 0.96)
def run():
    """Load the dataset, build the configured KGE model, and launch training.

    Improvement: the original unconditionally constructed TransE (with its
    embedding tensors) and then discarded it whenever another model was
    configured; a dispatch table now builds only the selected model.
    """
    set_logger()
    # Load entity/relation dictionaries and the three triple splits.
    ent_path = os.path.join(config.data_path, "entities.dict")
    rel_path = os.path.join(config.data_path, "relations.dict")
    ent2id = read_elements(ent_path)
    rel2id = read_elements(rel_path)
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(os.path.join(config.data_path, "train.txt"),
                                 ent2id, rel2id)
    valid_triples = read_triples(os.path.join(config.data_path, "valid.txt"),
                                 ent2id, rel2id)
    test_triples = read_triples(os.path.join(config.data_path, "test.txt"),
                                ent2id, rel2id)
    logging.info("#ent_num: %d" % ent_num)
    logging.info("#rel_num: %d" % rel_num)
    logging.info("#train triple num: %d" % len(train_triples))
    logging.info("#valid triple num: %d" % len(valid_triples))
    logging.info("#test triple num: %d" % len(test_triples))
    logging.info("#Model: %s" % config.model)
    # Build the selected model; unknown names fall back to TransE,
    # matching the original behavior.
    model_classes = {
        "TransE": TransE,
        "TransH": TransH,
        "TransR": TransR,
        "TransD": TransD,
        "STransE": STransE,
        "LineaRE": LineaRE,
        "DistMult": DistMult,
        "ComplEx": ComplEx,
        "RotatE": RotatE,
    }
    kge_model = model_classes.get(config.model, TransE)(ent_num, rel_num)
    if config.cuda:
        kge_model = kge_model.cuda()
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s"
                     % (name, str(param.size()), str(param.requires_grad)))
    # Launch training.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples),
          ent_num=ent_num)
def run():
    """Load data (including relation-pattern test splits), build the
    configured KGE model on GPU 0, and launch training.

    Improvement: the original always constructed TransE and then replaced
    it when another model was configured (wasted embedding allocation); a
    dispatch table now builds only the selected model.
    """
    # Load dictionaries and the standard splits.
    ent2id = read_elements(os.path.join(config.data_path, "entities.dict"))
    rel2id = read_elements(os.path.join(config.data_path, "relations.dict"))
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(os.path.join(config.data_path, "train.txt"),
                                 ent2id, rel2id)
    valid_triples = read_triples(os.path.join(config.data_path, "valid.txt"),
                                 ent2id, rel2id)
    test_triples = read_triples(os.path.join(config.data_path, "test.txt"),
                                ent2id, rel2id)
    # Extra test splits grouped by relation pattern.
    symmetry_test = read_triples(
        os.path.join(config.data_path, "symmetry_test.txt"), ent2id, rel2id)
    inversion_test = read_triples(
        os.path.join(config.data_path, "inversion_test.txt"), ent2id, rel2id)
    composition_test = read_triples(
        os.path.join(config.data_path, "composition_test.txt"), ent2id, rel2id)
    others_test = read_triples(
        os.path.join(config.data_path, "other_test.txt"), ent2id, rel2id)
    logging.info("#ent_num: %d" % ent_num)
    logging.info("#rel_num: %d" % rel_num)
    logging.info("#train triple num: %d" % len(train_triples))
    logging.info("#valid triple num: %d" % len(valid_triples))
    logging.info("#test triple num: %d" % len(test_triples))
    logging.info("#Model: %s" % config.model)
    # Build the selected model; unknown names fall back to TransE,
    # matching the original behavior. Note "TransR" maps to SimpleTransR
    # in this variant, as in the original chain.
    model_classes = {
        "TransE": TransE,
        "TransH": TransH,
        "TransR": SimpleTransR,
        "TransD": TransD,
        "STransE": STransE,
        "LineaRE": LineaRE,
        "DistMult": DistMult,
        "ComplEx": ComplEx,
        "RotatE": RotatE,
        "TransIJ": TransIJ,
    }
    kge_model = model_classes.get(config.model, TransE)(ent_num, rel_num)
    kge_model = kge_model.cuda(torch.device("cuda:0"))
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s"
                     % (name, str(param.size()), str(param.requires_grad)))
    # Launch training.
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples, symmetry_test,
                   inversion_test, composition_test, others_test),
          ent_num=ent_num)
def main():
    """CLI entry point: train, evaluate, or predict with TransE.

    --mode selects the action: 'train' runs the training loop, 'eval'
    restores the latest checkpoint and evaluates, 'predict' restores and
    runs prediction.

    Fixes: the error message advertised a 'test' mode that no branch
    handles (the valid modes are train|eval|predict); the special
    Saver-construction branch for mode=='test' was redundant anyway,
    since tf.train.Saver() without arguments already saves all variables.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--mode', type=str, default='eval')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--hit_at_n', type=int, default=10)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg, model_path=args.ckpt_dir,
                       embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       hit_at_n=args.hit_at_n)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    # One saver over all global variables covers every mode.
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session(config=sess_config) as sess:
        if args.mode == 'eval':
            print('-----Loading from checkpoints-----')
            ckpt_file = tf.train.latest_checkpoint(args.ckpt_dir)
            saver.restore(sess, ckpt_file)
            kge_model.launch_evaluation(session=sess)
        elif args.mode == 'train':
            print('-----Initializing tf graph-----')
            tf.global_variables_initializer().run()
            print('-----Initialization accomplished-----')
            kge_model.check_norm(session=sess)
            summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                                   graph=sess.graph)
            for epoch in range(args.max_epoch):
                print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                # Checkpointing/eval scheduling is handled inside
                # launch_training (it receives epoch and saver).
                kge_model.launch_training(epoch=epoch, session=sess,
                                          summary_writer=summary_writer,
                                          saver=saver)
        elif args.mode == 'predict':
            print('-----Loading from checkpoints-----')
            ckpt_file = tf.train.latest_checkpoint(args.ckpt_dir)
            saver.restore(sess, ckpt_file)
            kge_model.launch_prediction(session=sess)
        else:
            print('Wrong mode!! (mode = train|eval|predict)')
def main():
    """Train a TransE/DistMult KGE model, checkpoint periodically, and
    export the final embeddings.

    Fixes over the original:
    * epoch_loss accumulated the loss *tensor* across batches, which kept
      every batch's autograd graph alive for the whole epoch (memory
      leak); it now accumulates the detached Python float via .item().
    * The duplicated normalize-and-save block (periodic + final) is
      factored into one nested helper.
    """
    opts = get_train_args()
    print("load data ...")
    data = DataSet('data/modified_triples.txt')
    dataloader = DataLoader(data, shuffle=True, batch_size=opts.batch_size)
    print("load model ...")
    if opts.model_type == 'transe':
        model = TransE(opts, data.ent_tot, data.rel_tot)
    elif opts.model_type == "distmult":
        model = DistMult(opts, data.ent_tot, data.rel_tot)
    if opts.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opts.lr)
    elif opts.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opts.lr)
    model.cuda()
    model.relation_normalize()
    loss = torch.nn.MarginRankingLoss(margin=opts.margin)

    def save_model():
        # Normalize entity embeddings before persisting (shared by the
        # periodic and the final save).
        print("save model...")
        model.entity_normalize()
        torch.save(model.state_dict(), 'model.pt')

    print("start training")
    for epoch in range(1, opts.epochs + 1):
        print("epoch : " + str(epoch))
        model.train()
        epoch_start = time.time()
        epoch_loss = 0.0
        tot = 0
        cnt = 0
        for i, batch_data in enumerate(dataloader):
            optimizer.zero_grad()
            batch_h, batch_r, batch_t, batch_n = batch_data
            batch_h = torch.LongTensor(batch_h).cuda()
            batch_r = torch.LongTensor(batch_r).cuda()
            batch_t = torch.LongTensor(batch_t).cuda()
            batch_n = torch.LongTensor(batch_n).cuda()
            pos_score, neg_score, dist = model.forward(batch_h, batch_r,
                                                       batch_t, batch_n)
            pos_score = pos_score.cpu()
            neg_score = neg_score.cpu()
            dist = dist.cpu()
            # Margin ranking loss (target +1: pos should outrank neg)
            # plus the model's regularization term.
            train_loss = loss(pos_score, neg_score,
                              torch.ones(pos_score.size(-1))) + dist
            train_loss.backward()
            optimizer.step()
            # .item() detaches the scalar so no graph is retained.
            epoch_loss += train_loss.item()
            tot += batch_h.size(0)
            cnt += 1
            print('\r{:>10} epoch {} progress {} loss: {}\n'.format(
                '', epoch, tot / len(data), train_loss), end='')
        time_used = time.time() - epoch_start
        epoch_loss /= cnt
        print('one epoch time: {} minutes'.format(time_used / 60))
        print('{} epochs'.format(epoch))
        print('epoch {} loss: {}'.format(epoch, epoch_loss))
        if epoch % opts.save_step == 0:
            save_model()
    # Final save + embedding export after the last epoch.
    save_model()
    print("[Saving embeddings of whole entities & relations...]")
    save_embeddings(model, opts, data.id2ent, data.id2rel)
    print("[Embedding results are saved successfully.]")