def train():
    # dataloader for training
    train_dataloader = TrainDataLoader(in_path='./data/kg/',
                                       nbatches=100,
                                       threads=8,
                                       sampling_mode="normal",
                                       bern_flag=1,
                                       filter_flag=1,
                                       neg_ent=25,
                                       neg_rel=0)

    # define the model
    transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                    rel_tot=train_dataloader.get_rel_tot(),
                    dim=Config.entity_embedding_dim,
                    p_norm=1,
                    norm_flag=True)

    # define the loss function
    model = NegativeSampling(model=transe,
                             loss=MarginLoss(margin=5.0),
                             batch_size=train_dataloader.get_batch_size())

    # train the model
    trainer = Trainer(model=model,
                      data_loader=train_dataloader,
                      train_times=1000,
                      alpha=1.0,
                      use_gpu=True)
    trainer.run()
    transe.save_checkpoint('./data/kg/transe.ckpt')
def generate():
    # dataloader for training
    train_dataloader = TrainDataLoader(in_path='./data/kg/',
                                       nbatches=100,
                                       threads=8,
                                       sampling_mode="normal",
                                       bern_flag=1,
                                       filter_flag=1,
                                       neg_ent=25,
                                       neg_rel=0)

    # define the model
    transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                    rel_tot=train_dataloader.get_rel_tot(),
                    dim=Config.entity_embedding_dim,
                    p_norm=1,
                    norm_flag=True)
    transe.load_checkpoint('./data/kg/transe.ckpt')

    entity_embedding = transe.get_parameters()['ent_embeddings.weight']
    entity_embedding[0] = 0
    np.save('./data/kg/entity.npy', entity_embedding)

    # context embedding of an entity: mean of its own embedding and the
    # embeddings of every entity it shares a triple with
    context_embedding = np.empty_like(entity_embedding)
    context_embedding[0] = 0
    relation = pd.read_table('./data/sub_kg/triple2id.txt', header=None)[[0, 1]]
    entity = pd.read_table('./data/sub_kg/entity2name.txt',
                           header=None)[[0]].to_numpy().flatten()
    for e in entity:
        df = pd.concat([relation[relation[0] == e], relation[relation[1] == e]])
        context = list(set(np.append(df.to_numpy().flatten(), e)))
        context_embedding[e] = np.mean(entity_embedding[context, :], axis=0)
    np.save('./data/kg/context.npy', context_embedding)
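# A minimal sketch (not part of the original script) of how the arrays saved
# by generate() might be consumed downstream, e.g. as frozen lookup tables in
# PyTorch. The paths mirror the np.save calls above; padding_idx=0 matches the
# zeroed row 0.
import numpy as np
import torch
import torch.nn as nn

entity_weights = torch.from_numpy(np.load('./data/kg/entity.npy')).float()
context_weights = torch.from_numpy(np.load('./data/kg/context.npy')).float()
entity_emb = nn.Embedding.from_pretrained(entity_weights, freeze=True, padding_idx=0)
context_emb = nn.Embedding.from_pretrained(context_weights, freeze=True, padding_idx=0)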
extract_path_vec_list = []
with open("benchmarks/FKB/relation2id.txt") as f:
    f.readline()
    for line in f.readlines():
        extract_path_vec_list.append(path_vec_list[int(line.split('\t')[0])])
rel_embedding = nn.Embedding.from_pretrained(
    torch.from_numpy(
        np.array(extract_path_vec_list).astype(dtype='float64')).float())

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=50,
                p_norm=1,
                norm_flag=True)
transe.load_rel_embeddings(rel_embedding)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=MarginLoss(margin=10.0),
                         batch_size=train_dataloader.get_batch_size())

# freeze the pretrained relation embeddings so only entity embeddings train
for k, v in model.named_parameters():
    if k == 'model.rel_embeddings.weight':
        v.requires_grad = False

# train the model
                                   filter_flag=1,
                                   neg_ent=1,
                                   neg_rel=0)

# dataloader for test
# test_dataloader = TestDataLoader("../openke_data", "link")

pretrain_init = {
    'entity': '../concept_glove.max.npy',
    'relation': '../relation_glove.max.npy'
}

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=100,
                p_norm=1,
                margin=1.0,
                norm_flag=True,
                init='pretrain',
                init_weights=pretrain_init)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=SigmoidLoss(adv_temperature=1),
                         batch_size=train_dataloader.get_batch_size())

# train the model
checkpoint_dir = Path('./checkpoint/')
checkpoint_dir.mkdir(exist_ok=True, parents=True)
alpha = 0.001
trainer = Trainer(model=model,
                  data_loader=train_dataloader,
"--embedding", default=os.path.join(os.path.curdir, "kg_embed"), help="Path to saving embeddings") args = parser.parse_args() bench_path, ckpt_path, emb_path = args.benchmark, args.checkpoint, args.embedding # dataloader for training train_dataloader = TrainDataLoader(in_path=bench_path, nbatches=100, threads=16, sampling_mode="normal", bern_flag=1, filter_flag=1, neg_ent=25, neg_rel=0) # define the model transe = TransE(ent_tot=train_dataloader.get_ent_tot(), rel_tot=train_dataloader.get_rel_tot(), dim=100, p_norm=1, norm_flag=True) transe.load_checkpoint(os.path.join(ckpt_path, "transe.ckpt")) params = transe.get_parameters() np.savetxt(os.path.join(emb_path, "entity2vec.vec"), params["ent_embeddings.weight"]) np.savetxt(os.path.join(emb_path, "relation2vec.vec"), params["rel_embeddings.weight"])
    TASK_REV_MEDIUMHAND,
    TASK_LABELS,
)
import metrics
from utils import Task, openke_predict, get_entity_relationship_dicts

parser = argparse.ArgumentParser()
parser.add_argument("--model", default='transe')
args = parser.parse_args()

ent_list, rel_list = get_entity_relationship_dicts()

if args.model == 'transe':
    model = TransE(ent_tot=len(ent_list),
                   rel_tot=len(rel_list),
                   dim=200,
                   p_norm=1,
                   norm_flag=True)
elif args.model == 'transd':
    model = TransD(ent_tot=len(ent_list),
                   rel_tot=len(rel_list),
                   dim_e=200,
                   dim_r=200,
                   p_norm=1,
                   norm_flag=True)
elif args.model == 'rescal':
    model = RESCAL(ent_tot=len(ent_list), rel_tot=len(rel_list), dim=50)
elif args.model == 'distmult':
    model = DistMult(ent_tot=len(ent_list), rel_tot=len(rel_list), dim=200)
elif args.model == 'complex':
    model = ComplEx(ent_tot=len(ent_list), rel_tot=len(rel_list), dim=200)
train_dataloader = TrainDataLoader(
    # in_path = "./benchmarks/transe_ske/",
    in_path=base_root,
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=5)

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=200,
                p_norm=2,
                norm_flag=True)

save_path = os.path.join('checkpoint', phase, 'transe.ckpt')
transe.load_checkpoint(save_path)

rel_emb = transe.get_parameters()['rel_embeddings.weight']
ent_emb = transe.get_parameters()['ent_embeddings.weight']

e_emb, r_emb = dict(), dict()
with open(entity2id_path, 'r', encoding='utf-8') as f:
    next(f)
    for line in f:
        tmp = line.split('\t')
        # the id is the last field; the name may itself contain tabs
        entity = ''.join(tmp[:-1])
        e_emb[entity] = ent_emb[int(tmp[-1]), :]
from openke.data import TrainDataLoader, TestDataLoader
import pickle
import pathlib

# dataloader for training
train_dataloader = TrainDataLoader(in_path="./dbpedia50_openKE/kb2E/",
                                   nbatches=100,
                                   threads=8,
                                   bern_flag=1)

# dataloader for test
test_dataloader = TestDataLoader("./dbpedia50_openKE/kb2E/", "link")

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=300)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=MarginLoss(),
                         batch_size=train_dataloader.get_batch_size())

# train the model
trainer = Trainer(model=model,
                  data_loader=train_dataloader,
                  train_times=1000,
                  alpha=0.01,
                  use_gpu=True,
                  opt_method='adagrad')
trainer.run()
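# The test_dataloader defined above is otherwise unused; a likely follow-up
# (a sketch, assuming OpenKE's standard Tester API) is link-prediction
# evaluation of the trained model:
from openke.config import Tester

tester = Tester(model=transe, data_loader=test_dataloader, use_gpu=True)
tester.run_link_prediction(type_constrain=False)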
train_dataloader = TrainDataLoader(in_path="./benchmarks/WN18RR/", batch_size=2000, threads=8, sampling_mode="cross", bern_flag=0, filter_flag=1, neg_ent=64, neg_rel=0) # dataloader for test test_dataloader = TestDataLoader("./benchmarks/WN18RR/", "link") # define the model transe = TransE(ent_tot=train_dataloader.get_ent_tot(), rel_tot=train_dataloader.get_rel_tot(), dim=1024, p_norm=1, norm_flag=False, margin=6.0) # define the loss function model = NegativeSampling(model=transe, loss=SigmoidLoss(adv_temperature=1), batch_size=train_dataloader.get_batch_size(), regul_rate=0.0) # train the model trainer = Trainer(model=model, data_loader=train_dataloader, train_times=3000, alpha=2e-5, use_gpu=False,
train_dataloader = TrainDataLoader(in_path="./benchmarks/LUMB/", nbatches=100, threads=8, sampling_mode="normal", bern_flag=1, filter_flag=1, neg_ent=25, neg_rel=0) # dataloader for test #test_dataloader = TestDataLoader("./benchmarks/LUMB/", "link") # define the model transe = TransE(ent_tot=train_dataloader.get_ent_tot(), rel_tot=train_dataloader.get_rel_tot(), dim=1, p_norm=1, norm_flag=False) # define the loss function model = NegativeSampling(model=transe, loss=MarginLoss(margin=5.0), batch_size=train_dataloader.get_batch_size()) # train the model trainer = Trainer(model=model, data_loader=train_dataloader, train_times=100, alpha=1.0, use_gpu=False) trainer.run()
not_found = 0
for i in range(max_id):
    entity = id2entity[i]
    word = entity2name[entity]
    try:
        weights_matrix[i] = glove[word]
    except KeyError:
        # fall back to the GloVe 'unk' vector for out-of-vocabulary names
        weights_matrix[i] = glove['unk']
        not_found += 1

# define the model
transe = TransE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    ent_weight=weights_matrix,
    # rel_weight = cur_rel_weight,
    dim=200,
    p_norm=1,
    norm_flag=True)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=MarginLoss(margin=5.0),
                         batch_size=train_dataloader.get_batch_size())

# train the model
trainer = Trainer(model=model,
                  data_loader=train_dataloader,
                  train_times=1000,
                  alpha=1.0,
                  use_gpu=True)
        if result_code == 1:
            result = "FAKE NEWS"
            break
        else:
            result_true_count += 1
    if result is None:
        if result_true_count >= len(triples) / 2:
            result = "TRUE NEWS"
        else:
            result = "I'M NOT SURE"
    return result


transe = TransE(ent_tot=len(entity_map.keys()),
                rel_tot=len(relation_map.keys()),
                dim=1024,
                p_norm=1,
                norm_flag=False,
                margin=6.0)
transe.load_checkpoint('./checkpoint/transe_fn.ckpt')
tester = Tester(model=transe, use_gpu=False)

number_of_test_example = 100
db = Mongo().get_client()
print("Predicting random entity and relation ...")
result_number = 0
news_list = db['covid_news_data'].find({"status": 2})
if news_list:
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# dataloader for training
train_dataloader = TrainDataLoader(
    in_path="../OpenKEfiles/DBpedia/Restricted/",
    nbatches=1000,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0)

# dataloader for test
test_dataloader = TestDataLoader("../OpenKEfiles/DBpedia/Restricted/",
                                 "link",
                                 type_constrain=False)

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=200,
                p_norm=1,
                norm_flag=True)

# test the model
transe.load_checkpoint('../checkpoint/dbpedia/restricted/transe.ckpt')
print("loaded checkpoint")
save_path = "../checkpoint/dbpedia/restricted/transe.embedding.vec.json"
transe.save_parameters(save_path)
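# Sketch of reloading the exported parameters (an assumption based on OpenKE's
# BaseModule, whose load_parameters reads the JSON written by save_parameters):
transe2 = TransE(ent_tot=train_dataloader.get_ent_tot(),
                 rel_tot=train_dataloader.get_rel_tot(),
                 dim=200,
                 p_norm=1,
                 norm_flag=True)
transe2.load_parameters(save_path)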
# dataloader for training
train_dataloader = TrainDataLoader(
    # in_path = "./benchmarks/transe_ske/",
    in_path=base_path,
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=5)

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=200,
                p_norm=2,
                norm_flag=True)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=MarginLoss(margin=5.0),
                         batch_size=train_dataloader.get_batch_size())

# train the model
trainer = Trainer(model=model,
                  data_loader=train_dataloader,
                  train_times=1000,
                  alpha=1.0,
                  use_gpu=True)
trainer.run()
class QuestionAnswerModel(torch.nn.Module):
    def __init__(self, embed_model_path, bert_path, bert_name, n_clusters,
                 embed_method='rotatE', fine_tune=True, attention=True,
                 use_lstm=False, use_dnn=True, attention_method='mine',
                 num_layers=2, bidirectional=False):
        super(QuestionAnswerModel, self).__init__()
        self.embed_method = embed_method
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info('using device: {}'.format(self.device))
        self.relation_predictor = RelationPredictor(
            bert_path=bert_path,
            bert_name=bert_name,
            fine_tune=fine_tune,
            attention=attention,
            use_lstm=use_lstm,
            use_dnn=use_dnn,
            attention_method=attention_method,
            num_layers=num_layers,
            bidirectional=bidirectional).to(self.device)
        if self.embed_method == 'rotatE':
            self.score_func = self.rotatE
            self.KG_embed = RotatE(ent_tot=43234,
                                   rel_tot=18,
                                   dim=256,
                                   margin=6.0,
                                   epsilon=2.0)
        elif self.embed_method == 'complEx':
            self.score_func = self.complEx
            self.KG_embed = ComplEx(ent_tot=43234, rel_tot=18, dim=200)
        elif self.embed_method == 'DistMult':
            self.score_func = self.DistMult
            self.KG_embed = DistMult(ent_tot=43234, rel_tot=18, dim=200)
        elif self.embed_method == 'TransE':
            self.score_func = self.TransE
            self.KG_embed = TransE(ent_tot=43234,
                                   rel_tot=18,
                                   dim=200,
                                   p_norm=1,
                                   norm_flag=True)
        else:
            raise Exception('embed method not specified!')
        self.embed_model_path = embed_model_path
        self.KG_embed.load_checkpoint(self.embed_model_path)
        self.KG_embed.to(self.device)
        # freeze the pretrained KG embedding
        for param in self.KG_embed.parameters():
            param.requires_grad = False
        logger.info('loading pretrained KG embedding from {}'.format(
            self.embed_model_path))
        if self.embed_method == 'rotatE':
            self.cluster = KMeans(n_clusters=n_clusters)
            self.cluster2ent = [[] for _ in range(n_clusters)]
            for idx, label in enumerate(
                    self.cluster.fit_predict(
                        self.KG_embed.ent_embeddings.weight.cpu())):
                self.cluster2ent[label].append(idx)
        self.candidate_generator = CandidateGenerator(
            './MetaQA/KGE_data/train2id.txt')
        # cnt = 0
        # for _ in self.cluster2ent:
        #     cnt += len(_)
        # assert cnt == self.KG_embed.ent_tot

    def _to_tensor(self, inputs):
        return torch.tensor(inputs).to(self.device)

    def complEx(self, head, relation, tail):
        """
        return torch.sum(
            h_re * t_re * r_re +
            h_im * t_im * r_re +
            h_re * t_im * r_im -
            h_im * t_re * r_im,
            -1)
        :param head:
        :param relation:
        :param tail:
        :return:
        """
        batch_size = head.shape[1]
        target_size = tail.shape[1]
        re_head, im_head = torch.chunk(head.squeeze(2), 2, dim=0)
        re_tail, im_tail = torch.chunk(tail.squeeze(2), 2, dim=0)
        re_relation, im_relation = torch.chunk(relation.squeeze(2), 2, dim=0)
        # reshape everything to (batch_size, target_size, embed_size)
        re_head = re_head.expand(target_size, -1, -1).permute(1, 0, 2)
        im_head = im_head.expand(target_size, -1, -1).permute(1, 0, 2)
        re_tail = re_tail.expand(batch_size, -1, -1)
        im_tail = im_tail.expand(batch_size, -1, -1)
        im_relation = im_relation.expand(target_size, -1, -1).permute(1, 0, 2)
        re_relation = re_relation.expand(target_size, -1, -1).permute(1, 0, 2)
        score = torch.sum(
            re_head * re_tail * re_relation + im_head * im_tail * re_relation +
            re_head * im_tail * im_relation - im_head * re_tail * im_relation,
            -1)  # (batch_size, target_size)
        return score

    def TransE(self, head, relation, tail):
        batch_size = head.shape[0]
        target_size = tail.shape[0]
        if self.KG_embed.norm_flag:
            head = F.normalize(head, 2, -1)
            relation = F.normalize(relation, 2, -1)
            tail = F.normalize(tail, 2, -1)
        head = head.unsqueeze(0).expand(target_size, -1, -1).permute(1, 0, 2)
        relation = relation.unsqueeze(0).expand(target_size, -1,
                                                -1).permute(1, 0, 2)
        tail = tail.unsqueeze(0).expand(batch_size, -1, -1)
        score = head + relation - tail
        score = torch.norm(score, self.KG_embed.p_norm, -1)
        return -score

    def DistMult(self, head, relation, tail):
        batch_size = head.shape[0]
        target_size = tail.shape[0]
        head = head.unsqueeze(0).expand(target_size, -1, -1).permute(1, 0, 2)
        relation = relation.unsqueeze(0).expand(target_size, -1,
                                                -1).permute(1, 0, 2)
        tail = tail.unsqueeze(0).expand(batch_size, -1, -1)
        score = (head * relation) * tail
        score = torch.sum(score, dim=-1)
        return score

    def rotatE(self, head, relation, tail):
        """
        :param head: (batch_size, entity_embed)
        :param relation: (batch_size, relation_embed)
        :param tail: (target_size, entity_embed)
        :return: scores (batch_size, num_entity)
        """
        pi = self.KG_embed.pi_const
        batch_size = head.shape[0]
        target_size = tail.shape[0]
        re_head, im_head = torch.chunk(head, 2, dim=-1)
        re_tail, im_tail = torch.chunk(tail, 2, dim=-1)
        regularized_relation = relation / (
            self.KG_embed.rel_embedding_range.item() / pi)
        re_relation = torch.cos(regularized_relation)
        im_relation = torch.sin(regularized_relation)
        # (batch_size, ent_tot, entity_embed)
        re_head = re_head.unsqueeze(0).expand(target_size, -1,
                                              -1).permute(1, 0, 2)
        im_head = im_head.unsqueeze(0).expand(target_size, -1,
                                              -1).permute(1, 0, 2)
        re_tail = re_tail.unsqueeze(0).expand(batch_size, -1, -1)
        im_tail = im_tail.unsqueeze(0).expand(batch_size, -1, -1)
        im_relation = im_relation.unsqueeze(0).expand(target_size, -1,
                                                      -1).permute(1, 0, 2)
        re_relation = re_relation.unsqueeze(0).expand(target_size, -1,
                                                      -1).permute(1, 0, 2)
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation
        re_score = re_score - re_tail
        im_score = im_score - im_tail
        # stack the real and imaginary parts along a new leading dimension so
        # they can be reduced together
        score = torch.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0).sum(dim=-1)  # (batch_size, ent_tot)
        return self.KG_embed.margin - score

    def encode_question(self, question_token_ids, question_masks):
        return self.relation_predictor.encode_question_for_caching(
            self._to_tensor(question_token_ids),
            self._to_tensor(question_masks))

    def predict(self, question_token_ids, question_masks, head_id):
        scores = self.forward(question_token_ids, question_masks, head_id)
        predicts = torch.sort(scores.cpu(), dim=1, descending=True).indices
        return predicts

    # experiments showed sigmoid works best here
    def forward(self,
                question_token_ids,
                question_masks,
                head_id,
                last_hidden_states=None,
                use_cluster=False):
        if last_hidden_states is None:
            rel_scores = self.relation_predictor(
                self._to_tensor(question_token_ids),
                self._to_tensor(question_masks), None)
        else:
            rel_scores = self.relation_predictor(
                None, None, self._to_tensor(last_hidden_states))
        _index = [_[0] for _ in head_id]
        adjacency_scores = torch.index_select(
            self._to_tensor(self.relation_predictor.adjacencyMatrix), 0,
            self._to_tensor(_index))
        adjacency_scores = self.relation_predictor.adjacencyHandler(
            adjacency_scores)
        rel_scores = (rel_scores + adjacency_scores) / 2
        # the relation embedding is predicted as a linear combination of
        # self.KG_embed.rel_embeddings.weight, with sigmoid(rel_scores) as the
        # combination coefficients
        # predict_relation = torch.clip(predict_relation,
        #     min=-self.KG_embed.rel_embedding_range.weight.data,
        #     max=self.KG_embed.rel_embedding_range.weight.data)
        # predict_relation = self.relation_predictor(self._to_tensor(question_token_ids),
        #                                            self._to_tensor(question_masks))
        if self.embed_method == 'complEx':
            _tensor = self._to_tensor(head_id)
            head_embed = torch.stack([
                self.KG_embed.ent_re_embeddings(_tensor),
                self.KG_embed.ent_im_embeddings(_tensor)
            ], dim=0)
            predict_relation = torch.matmul(
                torch.sigmoid(rel_scores),
                torch.stack([
                    self.KG_embed.rel_re_embeddings.weight,
                    self.KG_embed.rel_im_embeddings.weight
                ], dim=0))
        else:
            head_embed = self.KG_embed.ent_embeddings(
                self._to_tensor(head_id)).squeeze(1)
            predict_relation = torch.matmul(
                torch.sigmoid(rel_scores), self.KG_embed.rel_embeddings.weight)
        # candidate_answers = list(self.candidate_generator.get_candidates(_index))
        indices = None
        if not use_cluster:
            if self.embed_method == 'complEx':
                tail_embed = torch.stack([
                    self.KG_embed.ent_re_embeddings.weight,
                    self.KG_embed.ent_im_embeddings.weight
                ], dim=0)
            else:
                tail_embed = self.KG_embed.ent_embeddings.weight
            # higher scores are better
            scores = self.score_func(head_embed, predict_relation, tail_embed)
            return scores
        else:
            centers = self.cluster.cluster_centers_
            cluster_scores = self.score_func(head_embed, predict_relation,
                                             self._to_tensor(centers))
            values, indices = torch.max(cluster_scores, dim=1)
            tail_embed = []
            for cluster_index in indices:
                tail_embed.append(
                    torch.index_select(
                        self.KG_embed.ent_embeddings.weight, 0,
                        self._to_tensor(self.cluster2ent[cluster_index])))
            scores = []
            for _head, _rel, _tail in zip(head_embed, predict_relation,
                                          tail_embed):
                scores.append(
                    self.score_func(_head.unsqueeze(0), _rel.unsqueeze(0),
                                    _tail))
            return scores, indices
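# Standalone sanity check (hypothetical shapes, not part of the original
# class) for the expand/permute broadcast pattern shared by the score
# functions above: head/relation are (batch_size, dim), tail is
# (target_size, dim), and the result is a (batch_size, target_size) score
# matrix, shown here for the TransE case with p = 1.
import torch

batch_size, target_size, dim = 4, 7, 16
head = torch.randn(batch_size, dim)
relation = torch.randn(batch_size, dim)
tail = torch.randn(target_size, dim)

h = head.unsqueeze(0).expand(target_size, -1, -1).permute(1, 0, 2)
r = relation.unsqueeze(0).expand(target_size, -1, -1).permute(1, 0, 2)
t = tail.unsqueeze(0).expand(batch_size, -1, -1)

score = -torch.norm(h + r - t, p=1, dim=-1)
assert score.shape == (batch_size, target_size)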
                                   nbatches=150,
                                   threads=8,
                                   sampling_mode="normal",
                                   bern_flag=1,
                                   filter_flag=1,
                                   neg_ent=25,
                                   neg_rel=0)

# dataloader for test
test_dataloader = TestDataLoader(in_path="./benchmarks/OMKG/",
                                 sampling_mode='link')

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=50,
                p_norm=1,
                norm_flag=True)
model_e = NegativeSampling(model=transe,
                           loss=MarginLoss(margin=4.0),
                           batch_size=train_dataloader.get_batch_size())

transr = TransR(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim_e=50,
                dim_r=100,
                p_norm=1,
                norm_flag=True,
                rand_init=False)
train_dataloader = TrainDataLoader(in_path="./benchmarks/LUMB/", nbatches=100, threads=8, sampling_mode="normal", bern_flag=1, filter_flag=1, neg_ent=25, neg_rel=0) # dataloader for test test_dataloader = TestDataLoader("./benchmarks/LUMB/", "link") # define the model transe = TransE(ent_tot=train_dataloader.get_ent_tot(), rel_tot=train_dataloader.get_rel_tot(), dim=200, p_norm=1, norm_flag=True) # define the loss function model = NegativeSampling(model=transe, loss=MarginLoss(margin=5.0), batch_size=train_dataloader.get_batch_size()) # train the model trainer = Trainer(model=model, data_loader=train_dataloader, train_times=100, alpha=1.0, use_gpu=False) trainer.run()
train_dataloader = TrainDataLoader(in_path=data_dir,
                                   nbatches=nbatches,
                                   threads=8,
                                   sampling_mode="cross",
                                   bern_flag=1,
                                   filter_flag=1,
                                   neg_ent=negative_samples,
                                   neg_rel=0)

# dataloader for test
test_dataloader = TestDataLoader(data_dir, "triple")

# define the model
transe = TransE(ent_tot=train_dataloader.get_ent_tot(),
                rel_tot=train_dataloader.get_rel_tot(),
                dim=embed_dim,
                p_norm=2,
                norm_flag=True)

# define the loss function
model = NegativeSampling(model=transe,
                         loss=MarginLoss(margin=margin),
                         batch_size=train_dataloader.get_batch_size())

# train the model
trainer = Trainer(model=model,
                  data_loader=train_dataloader,
                  opt_method="adam",
                  train_times=train_times,
                  alpha=alpha,
                  use_gpu=True,
                  checkpoint_dir=ckpt_path,
                  save_steps=100)
tester = Tester(model=transe, data_loader=test_dataloader, use_gpu=True)
trainer.run(tester, test_every=100)
print("Saving model to {0}...".format(ckpt_path))
transe.save_checkpoint(ckpt_path)