def load_model(self, model, num_entities=1861, num_relations=113):
    """Build ``self.model`` and zero the embedding rows of meta-task entities.

    Args:
        self: Object exposing ``args``, ``embedding_size``,
            ``pretrain_entity_embedding``, ``pretrain_relation_embedding``
            and ``meta_task_entity``.
        model: Unused; the architecture is selected by ``self.args.model``.
        num_entities: Size of the entity vocabulary passed to ``TransGEN``
            (defaults to the previously hard-coded 1861).
        num_relations: Size of the relation vocabulary passed to ``TransGEN``
            (defaults to the previously hard-coded 113).

    Raises:
        ValueError: if ``self.args.model`` is not ``'TransGEN'``.
    """
    if self.args.model != 'TransGEN':
        raise ValueError("Model Name <{}> is Wrong".format(self.args.model))

    self.model = TransGEN(
        self.embedding_size,
        self.embedding_size,
        num_entities,
        num_relations,
        args=self.args,
        entity_embedding=self.pretrain_entity_embedding,
        relation_embedding=self.pretrain_relation_embedding)

    # Reset the rows of every meta-task entity to zero (presumably so that
    # entities treated as unseen carry no pre-trained / initial vector).
    meta_task_entity = torch.LongTensor(self.meta_task_entity)
    self.model.entity_embedding.weight.data[meta_task_entity] = torch.zeros(
        len(meta_task_entity), self.embedding_size)
class Trainer(object):
    """Evaluate a trained TransGEN checkpoint with classification metrics.

    Construction loads the processed meta-task splits, optional pre-trained
    entity embeddings and the model.  ``train`` restores
    ``<exp_name>/best_mrr_model.pth`` and logs the PR / ROC / P@1 values that
    ``utils.metric_report`` produces for all, inductive and transductive test
    triplets -- once from a single forward pass (``eval``) and once averaged
    over ``args.mc_times`` samples (``mc_score_inference``).
    """

    # Vocabulary sizes ``load_model`` uses unless told otherwise.
    # NOTE(review): hard-coded for the dataset this class was written for;
    # pass explicit sizes to ``load_model`` for any other dataset.
    NUM_ENTITIES = 1861
    NUM_RELATIONS = 113

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.exp_name = self.experiment_name(args)
        self.best_mrr = 0
        self.use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(args.gpu)

        (self.filtered_triplets,
         self.meta_train_task_triplets,
         self.meta_valid_task_triplets,
         self.meta_test_task_triplets,
         self.meta_train_task_entity_to_triplets,
         self.meta_valid_task_entity_to_triplets,
         self.meta_test_task_entity_to_triplets) = utils.load_processed_data(
             './Dataset/processed_data/{}'.format(args.data))

        # Every entity that is a meta-task in the train, valid or test split.
        self.meta_task_entity = np.concatenate(
            (list(self.meta_train_task_entity_to_triplets.keys()),
             list(self.meta_valid_task_entity_to_triplets.keys()),
             list(self.meta_test_task_entity_to_triplets.keys())))
        self.meta_task_test_entity = torch.LongTensor(
            np.array(list(self.meta_test_task_entity_to_triplets.keys())))

        self.load_pretrain_embedding(data=args.data,
                                     model=args.pre_train_model)
        self.load_model(model=args.model)

        if self.use_cuda:
            self.model.cuda()
            self.meta_task_test_entity = self.meta_task_test_entity.cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.lr,
                                          weight_decay=self.args.weight_decay)

    def load_pretrain_embedding(self, data, model):
        """Set ``embedding_size`` and optionally load pre-trained entity vectors.

        ``data`` and ``model`` are accepted for interface compatibility; the
        file path is built from ``self.args``.  Relation embeddings are never
        pre-loaded by this class.
        """
        self.embedding_size = int(self.args.pre_train_emb_size)
        if self.args.pre_train:
            entity_file_name = './Pretraining/{}/{}_entity.npy'.format(
                self.args.data, self.args.pre_train_model)
            self.pretrain_entity_embedding = torch.Tensor(
                np.load(entity_file_name))
        else:
            self.pretrain_entity_embedding = None
        self.pretrain_relation_embedding = None

    def load_model(self, model, num_entities=NUM_ENTITIES,
                   num_relations=NUM_RELATIONS):
        """Build ``self.model`` and zero the embedding rows of meta-task entities.

        Args:
            model: Unused; the architecture is chosen by ``self.args.model``.
            num_entities: Size of the entity vocabulary.
            num_relations: Size of the relation vocabulary.

        Raises:
            ValueError: if ``self.args.model`` is not ``'TransGEN'``.
        """
        if self.args.model != 'TransGEN':
            raise ValueError("Model Name <{}> is Wrong".format(
                self.args.model))
        self.model = TransGEN(
            self.embedding_size,
            self.embedding_size,
            num_entities,
            num_relations,
            args=self.args,
            entity_embedding=self.pretrain_entity_embedding,
            relation_embedding=self.pretrain_relation_embedding)

        # Meta-task entities start from a zero vector (presumably so the model
        # cannot rely on a pre-trained embedding for "unseen" entities).
        meta_task_entity = torch.LongTensor(self.meta_task_entity)
        self.model.entity_embedding.weight.data[meta_task_entity] = \
            torch.zeros(len(meta_task_entity), self.embedding_size)

    def train(self):
        """Load the best checkpoint and log its test metrics.

        No optimisation happens here; the method name is kept for callers.
        The checkpoint is mapped onto ``cuda:<args.gpu>`` when CUDA is in use
        and onto the CPU otherwise.
        """
        # Read the GPU id from self.args rather than a module-level ``args``.
        map_location = ('cuda:{}'.format(self.args.gpu)
                        if self.use_cuda else 'cpu')
        checkpoint = torch.load('{}/best_mrr_model.pth'.format(self.exp_name),
                                map_location=map_location)
        self.model.load_state_dict(checkpoint['state_dict'])
        print("Using best epoch: {}, {}".format(checkpoint['epoch'],
                                                self.exp_name))

        # Test Code
        eval_types = {'normal': True, 'mc_score': True}

        if eval_types['normal']:
            tqdm.write("Results about Normal (Mean) Inference")
            with torch.no_grad():
                reports = self.eval(eval_type='test')
            self._log_results(*reports)

        if eval_types['mc_score']:
            tqdm.write("Results about MC score inference")
            with torch.no_grad():
                reports = self.mc_score_inference(eval_type='test')
            self._log_results(*reports)

    def _log_results(self, total_results, total_induc_results,
                     total_trans_results):
        """Collect PR / ROC / P@1 of each triplet group, log them, return dict."""
        groups = (('total', 'Total', total_results),
                  ('total_induc', 'Total Induc', total_induc_results),
                  ('total_trans', 'Total Trans', total_trans_results))
        results = {}
        for key, _, report in groups:
            results['{}_prs'.format(key)] = report['pr']
            results['{}_rocs'.format(key)] = report['roc']
            results['{}_p@1s'.format(key)] = report['acc']
        for key, label, _ in groups:
            tqdm.write("{} PR (filtered): {:.6f}".format(
                label, results['{}_prs'.format(key)]))
            tqdm.write("{} ROC (filtered) {:.6f}".format(
                label, results['{}_rocs'.format(key)]))
            tqdm.write("{} P@1s (filtered) {:.6f}".format(
                label, results['{}_p@1s'.format(key)]))
        return results

    def _task_pool(self, eval_type):
        """Return ``(entity -> triplets dict, list of its task entities)``.

        Raises:
            ValueError: if ``eval_type`` is neither ``'valid'`` nor ``'test'``.
        """
        if eval_type == 'valid':
            task_dict = self.meta_valid_task_entity_to_triplets
        elif eval_type == 'test':
            task_dict = self.meta_test_task_entity_to_triplets
        else:
            raise ValueError("Eval Type <{}> is Wrong".format(eval_type))
        return task_dict, list(task_dict.keys())

    def _split_test_triplets(self, total_task_entity,
                             total_test_task_triplets_dict):
        """Return ``(all, inductive, transductive)`` lists of test triplets.

        A triplet is transductive when both its head and tail are task
        entities (see ``is_trans``).
        """
        my_total_triplets = []
        my_induc_triplets = []
        my_trans_triplets = []
        for test_triplets in total_test_task_triplets_dict.values():
            if self.use_cuda:
                test_triplets = test_triplets.cuda()
            for test_triplet in test_triplets:
                my_total_triplets.append(test_triplet)
                if self.is_trans(total_task_entity, test_triplet):
                    my_trans_triplets.append(test_triplet)
                else:
                    my_induc_triplets.append(test_triplet)
        return my_total_triplets, my_induc_triplets, my_trans_triplets

    def _predict_report(self, total_task_entity, task_entity_embeddings,
                        triplets):
        """Run ``model.predict`` on *triplets* and return its metric report."""
        triplets = torch.stack(triplets, dim=0)
        y_prob, y = self.model.predict(total_task_entity,
                                       task_entity_embeddings,
                                       triplets,
                                       target=None,
                                       use_cuda=self.use_cuda)
        return utils.metric_report(y.detach().cpu().numpy(),
                                   y_prob.detach().cpu().numpy())

    def _mc_predict_report(self, total_task_entity, task_entity_embeddings,
                           triplets):
        """Average ``model.predict`` probabilities over MC samples, then report.

        ``task_entity_embeddings[i]`` holds the task embeddings of sample i;
        the labels of the last sample are used for the report.
        """
        triplets = torch.stack(triplets, dim=0)
        y_prob_sum = None
        y = None
        for mc_index in range(self.args.mc_times):
            y_prob, y = self.model.predict(total_task_entity,
                                           task_entity_embeddings[mc_index],
                                           triplets,
                                           target=None,
                                           use_cuda=self.use_cuda)
            y_prob_sum = y_prob if y_prob_sum is None else y_prob_sum + y_prob
        if y_prob_sum is None:
            raise ValueError("args.mc_times must be at least 1")
        y_prob_mean = y_prob_sum / self.args.mc_times
        return utils.metric_report(y.detach().cpu().numpy(),
                                   y_prob_mean.detach().cpu().numpy())

    def eval(self, eval_type='test'):
        """Score held-out triplets of every meta-task with one forward pass.

        For each task entity the first ``args.few`` triplets are the support
        set and the remainder are test triplets; tasks without any test
        triplet are skipped.  Support sets are encoded per entity
        (inductive), then all task entities are encoded jointly
        (transductive).

        Returns:
            ``(total, inductive, transductive)`` reports of
            ``utils.metric_report``.
        """
        self.model.eval()
        test_task_dict, test_task_pool = self._task_pool(eval_type)

        total_task_entity = []
        total_task_entity_embedding = []
        total_train_task_triplets = []
        total_test_task_triplets_dict = dict()
        for task_entity in tqdm(test_task_pool):
            task_triplets = np.array(test_task_dict[task_entity])
            if len(task_triplets) - self.args.few < 1:
                continue
            train_task_triplets = task_triplets[:self.args.few]
            test_task_triplets = task_triplets[self.args.few:]

            # Train (Inductive)
            task_entity_embedding = self.model(task_entity,
                                               train_task_triplets,
                                               use_cuda=self.use_cuda,
                                               is_trans=False)
            total_task_entity.append(task_entity)
            total_task_entity_embedding.append(task_entity_embedding)
            total_train_task_triplets.extend(train_task_triplets)
            total_test_task_triplets_dict[task_entity] = torch.LongTensor(
                test_task_triplets)

        # Train (Transductive)
        total_task_entity = np.array(total_task_entity)
        total_task_entity_embedding = torch.cat(
            total_task_entity_embedding).view(-1, self.embedding_size)
        total_train_task_triplets = np.array(total_train_task_triplets)
        task_entity_embeddings, _, _ = self.model(
            total_task_entity,
            total_train_task_triplets,
            use_cuda=self.use_cuda,
            is_trans=True,
            total_unseen_entity_embedding=total_task_entity_embedding)

        # Test
        total_task_entity = torch.from_numpy(total_task_entity)
        if self.use_cuda:
            total_task_entity = total_task_entity.cuda()
        splits = self._split_test_triplets(total_task_entity,
                                           total_test_task_triplets_dict)
        return tuple(
            self._predict_report(total_task_entity, task_entity_embeddings,
                                 triplets)
            for triplets in splits)

    def mc_score_inference(self, eval_type='test'):
        """Like ``eval`` but averages predictions over ``args.mc_times`` samples.

        ``mc_times`` inductive embeddings are drawn per task entity and each
        is refined by its own transductive pass.  The model is switched to
        training mode before the transductive passes (presumably to keep
        stochastic layers such as dropout active -- confirm against TransGEN)
        and is left in that mode on return.

        Returns:
            ``(total, inductive, transductive)`` reports of
            ``utils.metric_report`` on the sample-averaged probabilities.
        """
        self.model.eval()
        test_task_dict, test_task_pool = self._task_pool(eval_type)
        mc_times = self.args.mc_times

        total_task_entity = []
        total_task_entity_embeddings = []
        total_train_task_triplets = []
        total_test_task_triplets_dict = dict()
        for task_entity in tqdm(test_task_pool):
            task_triplets = np.array(test_task_dict[task_entity])
            if len(task_triplets) - self.args.few < 1:
                continue
            train_task_triplets = task_triplets[:self.args.few]
            test_task_triplets = task_triplets[self.args.few:]

            # Train (Inductive): mc_times embeddings of this entity.
            task_entity_embedding = torch.cat([
                self.model(task_entity,
                           train_task_triplets,
                           use_cuda=self.use_cuda,
                           is_trans=False) for _ in range(mc_times)
            ]).view(-1, self.embedding_size)
            total_task_entity.append(task_entity)
            total_task_entity_embeddings.append(task_entity_embedding)
            total_train_task_triplets.extend(train_task_triplets)
            total_test_task_triplets_dict[task_entity] = torch.LongTensor(
                test_task_triplets)

        # Train (Transductive): one joint pass per MC sample.
        total_task_entity = np.array(total_task_entity)
        total_task_entity_embeddings = torch.cat(
            total_task_entity_embeddings).view(-1, mc_times,
                                               self.embedding_size)
        total_train_task_triplets = np.array(total_train_task_triplets)
        self.model.train()
        task_entity_embeddings = torch.cat([
            self.model(total_task_entity,
                       total_train_task_triplets,
                       use_cuda=self.use_cuda,
                       is_trans=True,
                       total_unseen_entity_embedding=(
                           total_task_entity_embeddings[:, i]))[0]
            for i in range(mc_times)
        ]).view(mc_times, -1, self.embedding_size)

        # Test
        total_task_entity = torch.from_numpy(total_task_entity)
        if self.use_cuda:
            total_task_entity = total_task_entity.cuda()
        splits = self._split_test_triplets(total_task_entity,
                                           total_test_task_triplets_dict)
        return tuple(
            self._mc_predict_report(total_task_entity,
                                    task_entity_embeddings, triplets)
            for triplets in splits)

    def is_trans(self, total_task_entity, test_triplet):
        """Return True when both head and tail of *test_triplet* are task entities."""
        return ((test_triplet[0] in total_task_entity)
                and (test_triplet[2] in total_task_entity))

    def experiment_name(self, args):
        """Return the checkpoint directory ``./checkpoints/<args.exp_name>``."""
        return os.path.join('./checkpoints', self.args.exp_name)
class Trainer(object):
    """Evaluate a trained TransGEN checkpoint with filtered ranking metrics.

    Construction loads the raw and processed datasets, optional pre-trained
    entity/relation embeddings and the model.  ``train`` restores
    ``<exp_name>/best_mrr_model.pth`` and logs MRR and Hits@1/3/10 of the
    true entity among filtered candidates, for all, inductive and
    transductive test triplets -- once from a single forward pass (``eval``)
    and once from scores averaged over ``args.mc_times`` samples
    (``mc_score_inference``).
    """

    # (results-key prefix, log label) for each group of test triplets.
    _RANK_SPLITS = (('total', 'Total'),
                    ('total_induc', 'Total Induc'),
                    ('total_trans', 'Total Trans'))
    # Cut-offs reported as Hits@k.
    _HITS = (1, 3, 10)

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.exp_name = self.experiment_name(args)
        self.best_mrr = 0
        self.use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(args.gpu)

        (self.entity2id, self.relation2id, self.train_triplets,
         self.valid_triplets, self.test_triplets) = utils.load_data(
             './Dataset/raw_data/{}'.format(args.data))
        (self.filtered_triplets,
         self.meta_train_task_triplets,
         self.meta_valid_task_triplets,
         self.meta_test_task_triplets,
         self.meta_train_task_entity_to_triplets,
         self.meta_valid_task_entity_to_triplets,
         self.meta_test_task_entity_to_triplets) = utils.load_processed_data(
             './Dataset/processed_data/{}'.format(args.data))

        # Every known triplet; used to filter candidates at ranking time.
        self.all_triplets = torch.LongTensor(
            np.concatenate((self.train_triplets, self.valid_triplets,
                            self.test_triplets)))
        # Every entity that is a meta-task in the train, valid or test split.
        self.meta_task_entity = np.concatenate(
            (list(self.meta_train_task_entity_to_triplets.keys()),
             list(self.meta_valid_task_entity_to_triplets.keys()),
             list(self.meta_test_task_entity_to_triplets.keys())))
        self.meta_task_triplets = torch.LongTensor(
            np.concatenate((self.meta_train_task_triplets,
                            self.meta_valid_task_triplets,
                            self.meta_test_task_triplets)))
        self.meta_task_test_entity = torch.LongTensor(
            np.array(list(self.meta_test_task_entity_to_triplets.keys())))

        self.load_pretrain_embedding(data=args.data,
                                     model=args.pre_train_model)
        self.load_model(model=args.model)

        if self.use_cuda:
            self.model.cuda()
            self.all_triplets = self.all_triplets.cuda()
            self.meta_task_triplets = self.meta_task_triplets.cuda()
            self.meta_task_test_entity = self.meta_task_test_entity.cuda()

        # (head, relation) and (tail, relation) columns of every known
        # triplet, matched against a query to find the entities to filter.
        self.head_relation_triplets = self.all_triplets[:, :2]
        self.tail_relation_triplets = torch.stack(
            (self.all_triplets[:, 2], self.all_triplets[:, 1])).transpose(0, 1)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.lr,
                                          weight_decay=self.args.weight_decay)

    def load_pretrain_embedding(self, data, model):
        """Set ``embedding_size`` and optionally load pre-trained embeddings.

        Files are read from ``./Pretraining/<data>/`` as
        ``<pre_train_model>_{entity,relation}_<size>.npy``.  ``data`` and
        ``model`` are accepted for interface compatibility; the paths are
        built from ``self.args``.
        """
        self.embedding_size = int(self.args.pre_train_emb_size)
        if not self.args.pre_train:
            self.pretrain_entity_embedding = None
            self.pretrain_relation_embedding = None
            return

        pretrain_model_path = './Pretraining/{}'.format(self.args.data)
        entity_file_name = os.path.join(
            pretrain_model_path,
            '{}_entity_{}.npy'.format(self.args.pre_train_model,
                                      self.embedding_size))
        relation_file_name = os.path.join(
            pretrain_model_path,
            '{}_relation_{}.npy'.format(self.args.pre_train_model,
                                        self.embedding_size))
        self.pretrain_entity_embedding = torch.Tensor(
            np.load(entity_file_name))
        self.pretrain_relation_embedding = torch.Tensor(
            np.load(relation_file_name))

    def load_model(self, model):
        """Build ``self.model`` and zero the embedding rows of meta-task entities.

        Args:
            model: Unused; the architecture is chosen by ``self.args.model``.

        Raises:
            ValueError: if ``self.args.model`` is not ``'TransGEN'``.
        """
        if self.args.model != 'TransGEN':
            raise ValueError("Model Name <{}> is Wrong".format(
                self.args.model))
        self.model = TransGEN(
            self.embedding_size,
            self.embedding_size,
            len(self.entity2id),
            len(self.relation2id),
            args=self.args,
            entity_embedding=self.pretrain_entity_embedding,
            relation_embedding=self.pretrain_relation_embedding)

        # Meta-task entities start from a zero vector (presumably so the model
        # cannot rely on a pre-trained embedding for "unseen" entities).
        meta_task_entity = torch.LongTensor(self.meta_task_entity)
        self.model.entity_embedding.weight.data[meta_task_entity] = \
            torch.zeros(len(meta_task_entity), self.embedding_size)

    def train(self):
        """Load the best checkpoint and log its filtered MRR / Hits@k.

        No optimisation happens here; the method name is kept for callers.
        The checkpoint is mapped onto ``cuda:<args.gpu>`` when CUDA is in use
        and onto the CPU otherwise.
        """
        # Read the GPU id from self.args rather than a module-level ``args``.
        map_location = ('cuda:{}'.format(self.args.gpu)
                        if self.use_cuda else 'cpu')
        checkpoint = torch.load('{}/best_mrr_model.pth'.format(self.exp_name),
                                map_location=map_location)
        self.model.load_state_dict(checkpoint['state_dict'])
        print("Using best epoch: {}, {}".format(checkpoint['epoch'],
                                                self.exp_name))

        eval_types = {'normal': True, 'mc_score_inference': True}

        if eval_types['normal']:
            tqdm.write("Results about Normal (Mean) Inference")
            with torch.no_grad():
                ranks = self.eval(eval_type='test')
            self._log_rank_metrics(self._rank_metrics(*ranks))

        if eval_types['mc_score_inference']:
            tqdm.write("Results about MC score inference")
            with torch.no_grad():
                ranks = self.mc_score_inference(eval_type='test', prob=False)
            self._log_rank_metrics(self._rank_metrics(*ranks))

    @classmethod
    def _rank_metrics(cls, total_ranks, total_induc_ranks, total_trans_ranks):
        """Return ``{prefix_mrrs, prefix_hits@ks}`` for each group of ranks.

        MRR is the mean of ``1 / rank``; Hits@k is the fraction of ranks that
        are ``<= k``.
        """
        metrics = {}
        groups = (total_ranks, total_induc_ranks, total_trans_ranks)
        for (key, _), ranks in zip(cls._RANK_SPLITS, groups):
            metrics['{}_mrrs'.format(key)] = torch.mean(
                1.0 / ranks.float()).item()
            for hit in cls._HITS:
                metrics['{}_hits@{}s'.format(key, hit)] = torch.mean(
                    (ranks <= hit).float()).item()
        return metrics

    @classmethod
    def _log_rank_metrics(cls, metrics):
        """Write the metrics produced by ``_rank_metrics`` with ``tqdm.write``."""
        for key, label in cls._RANK_SPLITS:
            tqdm.write("{} MRR (filtered): {:.6f}".format(
                label, metrics['{}_mrrs'.format(key)]))
            for hit in cls._HITS:
                tqdm.write("{} Hits (filtered) @ {}: {:.6f}".format(
                    label, hit, metrics['{}_hits@{}s'.format(key, hit)]))

    def _task_pool(self, eval_type):
        """Return ``(entity -> triplets dict, list of its task entities)``.

        Raises:
            ValueError: if ``eval_type`` is neither ``'valid'`` nor ``'test'``.
        """
        if eval_type == 'valid':
            task_dict = self.meta_valid_task_entity_to_triplets
        elif eval_type == 'test':
            task_dict = self.meta_test_task_entity_to_triplets
        else:
            raise ValueError("Eval Type <{}> is Wrong".format(eval_type))
        return task_dict, list(task_dict.keys())

    def _embedding_tables(self, total_task_entity, task_entity_embeddings):
        """Return detached copies of the entity and relation tables.

        Rows ``total_task_entity`` of the entity copy are overwritten with
        ``task_entity_embeddings``; the model's parameters are not modified.
        """
        entity_table = self.model.entity_embedding.weight.detach().clone()
        relation_table = self.model.relation_embedding.detach().clone()
        entity_table[total_task_entity] = task_entity_embeddings
        return entity_table, relation_table

    def eval(self, eval_type='test'):
        """Rank the true entity of every held-out meta-task triplet.

        For each task entity the first ``args.few`` triplets are the support
        set and the remainder are test triplets; tasks without any test
        triplet are skipped.  Support sets are encoded per entity
        (inductive), then all task entities are encoded jointly
        (transductive), and the resulting embeddings are written into a copy
        of the entity table used for scoring.

        Returns:
            ``(total_ranks, total_induc_ranks, total_trans_ranks)``.
        """
        self.model.eval()
        test_task_dict, test_task_pool = self._task_pool(eval_type)

        total_task_entity = []
        total_task_entity_embedding = []
        total_train_task_triplets = []
        total_test_task_triplets_dict = dict()
        for task_entity in tqdm(test_task_pool):
            task_triplets = np.array(test_task_dict[task_entity])
            if len(task_triplets) - self.args.few < 1:
                continue
            train_task_triplets = task_triplets[:self.args.few]
            test_task_triplets = task_triplets[self.args.few:]

            # Train (Inductive)
            task_entity_embedding = self.model(task_entity,
                                               train_task_triplets,
                                               use_cuda=self.use_cuda,
                                               is_trans=False)
            total_task_entity.append(task_entity)
            total_task_entity_embedding.append(task_entity_embedding)
            total_train_task_triplets.extend(train_task_triplets)
            total_test_task_triplets_dict[task_entity] = torch.LongTensor(
                test_task_triplets)

        # Train (Transductive)
        total_task_entity = np.array(total_task_entity)
        total_task_entity_embedding = torch.cat(
            total_task_entity_embedding).view(-1, self.embedding_size)
        total_train_task_triplets = np.array(total_train_task_triplets)
        task_entity_embeddings, _, _ = self.model(
            total_task_entity,
            total_train_task_triplets,
            use_cuda=self.use_cuda,
            is_trans=True,
            total_unseen_entity_embedding=total_task_entity_embedding)

        # Test
        total_task_entity = torch.from_numpy(total_task_entity)
        if self.use_cuda:
            total_task_entity = total_task_entity.cuda()
        all_entity_embeddings, all_relation_embeddings = \
            self._embedding_tables(total_task_entity, task_entity_embeddings)

        total_ranks = []
        total_induc_ranks = []
        total_trans_ranks = []
        for task_entity, test_triplets in \
                total_test_task_triplets_dict.items():
            if self.use_cuda:
                test_triplets = test_triplets.cuda()
            for test_triplet in test_triplets:
                rank, (is_trans, _, _) = self.calc_rank(
                    task_entity, total_task_entity,
                    all_entity_embeddings[task_entity],
                    all_entity_embeddings, all_relation_embeddings,
                    test_triplet, self.all_triplets,
                    use_cuda=self.use_cuda,
                    score_function=self.args.score_function)
                # +1 so MRR / Hits@k treat the best position as rank 1
                # (assumes utils.sort_and_rank is 0-based).
                rank = (rank + 1).cpu()
                total_ranks.append(rank)
                if is_trans:
                    total_trans_ranks.append(rank)
                else:
                    total_induc_ranks.append(rank)

        return (torch.cat(total_ranks), torch.cat(total_induc_ranks),
                torch.cat(total_trans_ranks))

    @staticmethod
    def _mean_score_ranks(per_sample_scores):
        """Average each triplet's scores over MC samples and rank candidate 0.

        Args:
            per_sample_scores: One list per MC sample, each holding one
                ``(1, num_candidates)`` score tensor per test triplet in the
                same triplet order.

        Returns:
            Concatenated ``utils.sort_and_rank`` results plus one.
        """
        target = torch.tensor(0)
        ranks = []
        for sample_scores in zip(*per_sample_scores):
            score = torch.mean(torch.cat(sample_scores), dim=0, keepdim=True)
            ranks.append(utils.sort_and_rank(score, target) + 1)
        return torch.cat(ranks)

    def mc_score_inference(self, eval_type='test', prob=False):
        """Rank test triplets using candidate scores averaged over MC samples.

        ``args.mc_times`` inductive embeddings are drawn per task entity and
        each is refined by its own transductive pass.  The model is switched
        to training mode before those passes (presumably to keep stochastic
        layers such as dropout active -- confirm against TransGEN) and is
        left in that mode on return.

        Args:
            eval_type: ``'valid'`` or ``'test'``.
            prob: If true, softmax each sample's scores before averaging.

        Returns:
            ``(total_ranks, total_induc_ranks, total_trans_ranks)``.
        """
        self.model.eval()
        test_task_dict, test_task_pool = self._task_pool(eval_type)
        mc_times = self.args.mc_times

        total_task_entity = []
        total_task_entity_embeddings = []
        total_train_task_triplets = []
        total_test_task_triplets_dict = dict()
        for task_entity in tqdm(test_task_pool):
            task_triplets = np.array(test_task_dict[task_entity])
            if len(task_triplets) - self.args.few < 1:
                continue
            train_task_triplets = task_triplets[:self.args.few]
            test_task_triplets = task_triplets[self.args.few:]

            # Train (Inductive): mc_times embeddings of this entity.
            task_entity_embedding = torch.cat([
                self.model(task_entity,
                           train_task_triplets,
                           use_cuda=self.use_cuda,
                           is_trans=False) for _ in range(mc_times)
            ]).view(-1, self.embedding_size)
            total_task_entity.append(task_entity)
            total_task_entity_embeddings.append(task_entity_embedding)
            total_train_task_triplets.extend(train_task_triplets)
            total_test_task_triplets_dict[task_entity] = torch.LongTensor(
                test_task_triplets)

        self.model.train()

        # Train (Transductive): one joint pass per MC sample.
        total_task_entity = np.array(total_task_entity)
        total_task_entity_embeddings = torch.cat(
            total_task_entity_embeddings).view(-1, mc_times,
                                               self.embedding_size)
        total_train_task_triplets = np.array(total_train_task_triplets)
        task_entity_embeddings = torch.cat([
            self.model(total_task_entity,
                       total_train_task_triplets,
                       use_cuda=self.use_cuda,
                       is_trans=True,
                       total_unseen_entity_embedding=(
                           total_task_entity_embeddings[:, i]))[0]
            for i in range(mc_times)
        ]).view(mc_times, -1, self.embedding_size)

        # Test
        total_task_entity = torch.from_numpy(total_task_entity)
        if self.use_cuda:
            total_task_entity = total_task_entity.cuda()

        total_scores = []
        total_induc_scores = []
        total_trans_scores = []
        for mc_index in trange(mc_times):
            all_entity_embeddings, all_relation_embeddings = \
                self._embedding_tables(total_task_entity,
                                       task_entity_embeddings[mc_index, :])
            scores = []
            induc_scores = []
            trans_scores = []
            for task_entity, test_triplets in \
                    total_test_task_triplets_dict.items():
                if self.use_cuda:
                    test_triplets = test_triplets.cuda()
                for test_triplet in test_triplets:
                    score, (is_trans, _, _) = self.calc_score(
                        task_entity, total_task_entity,
                        all_entity_embeddings[task_entity],
                        all_entity_embeddings, all_relation_embeddings,
                        test_triplet, self.all_triplets,
                        use_cuda=self.use_cuda,
                        score_function=self.args.score_function)
                    if prob:
                        score = F.softmax(score, dim=1)
                    score = score.to('cpu')
                    scores.append(score)
                    if is_trans:
                        trans_scores.append(score)
                    else:
                        induc_scores.append(score)
            total_scores.append(scores)
            total_induc_scores.append(induc_scores)
            total_trans_scores.append(trans_scores)

        return (self._mean_score_ranks(total_scores),
                self._mean_score_ranks(total_induc_scores),
                self._mean_score_ranks(total_trans_scores))

    def calc_rank(self, task_entity, total_task_entity, task_entity_embedding,
                  entity_embeddings, relation_embeddings, test_triplet,
                  all_triplets, use_cuda, score_function):
        """Rank the true entity among the filtered candidates of *test_triplet*.

        Scores come from ``calc_score`` (whose candidate 0 is the true
        entity) and are softmax-normalised before ranking.  Softmax is
        monotonic, so it only matters where float rounding creates ties.

        Returns:
            ``(rank, (is_trans, is_subject, is_object))`` where ``rank`` is
            ``utils.sort_and_rank(score, 0)``.
        """
        score, flags = self.calc_score(task_entity, total_task_entity,
                                       task_entity_embedding,
                                       entity_embeddings, relation_embeddings,
                                       test_triplet, all_triplets, use_cuda,
                                       score_function)
        score = F.softmax(score, dim=1)
        target = torch.tensor(0)
        if use_cuda:
            target = target.to(torch.device('cuda'))
        rank = utils.sort_and_rank(score, target)
        return rank, flags

    @staticmethod
    def _filtered_candidates(pairs, anchor, relation, known_column,
                             true_entity, num_entity):
        """Return ``[true_entity]`` followed by every non-filtered entity id.

        An entity is filtered when some row of ``pairs`` equals
        ``(anchor, relation)`` and ``known_column`` holds that entity in the
        same row.  Remaining ids are returned in ascending order.
        """
        key = torch.stack((anchor, relation)).to(pairs.device)
        matched = (pairs == key).all(dim=1)
        known = known_column[matched]
        keep = torch.ones(num_entity, dtype=torch.bool, device=known.device)
        keep[known] = False
        candidates = torch.nonzero(keep).view(-1)
        return torch.cat((true_entity.view(-1).to(candidates.device),
                          candidates))

    @staticmethod
    def _score_candidates(score_function, task_entity_embedding,
                          relation_embedding, candidate_embeddings,
                          task_is_subject):
        """Score each candidate against the task entity; shape ``(1, C)``.

        Raises:
            TypeError: for a ``score_function`` other than DistMult / TransE.
        """
        if score_function == 'DistMult':
            emb_ar = (task_entity_embedding * relation_embedding).view(-1, 1, 1)
            emb_c = candidate_embeddings.transpose(0, 1).unsqueeze(1)
            return torch.sum(torch.bmm(emb_ar, emb_c), dim=0)
        if score_function == 'TransE':
            if task_is_subject:
                diff = (task_entity_embedding + relation_embedding
                        - candidate_embeddings)
            else:
                diff = (candidate_embeddings + relation_embedding
                        - task_entity_embedding)
            return -torch.norm(diff, p=2, dim=1).view(1, -1)
        raise TypeError("Unknown score function <{}>".format(score_function))

    def calc_score(self, task_entity, total_task_entity, task_entity_embedding,
                   entity_embeddings, relation_embeddings, test_triplet,
                   all_triplets, use_cuda, score_function):
        """Score the filtered candidates for the task entity's missing partner.

        When the task entity is the head of *test_triplet* the candidates are
        tails, otherwise (it is the tail) they are heads.  Candidates are the
        true partner followed by every entity that does not complete a known
        triplet with the query's (entity, relation) pair, so column 0 of the
        returned score belongs to the true triplet.

        Args:
            use_cuda: Kept for interface compatibility; candidate indices are
                built on the device of ``self.head_relation_triplets`` /
                ``self.tail_relation_triplets``.

        Returns:
            ``(score, (is_trans, is_subject, is_object))`` with ``score`` of
            shape ``(1, num_candidates)``.

        Raises:
            TypeError: for an unknown ``score_function``.
            ValueError: if ``task_entity`` is neither head nor tail.
        """
        num_entity = len(entity_embeddings)
        subject = test_triplet[0]
        relation = test_triplet[1]
        object_ = test_triplet[2]

        is_trans = bool((subject in total_task_entity)
                        and (object_ in total_task_entity))
        is_subject = bool(subject == task_entity)
        is_object = (not is_subject) and bool(object_ == task_entity)

        if is_subject:
            # Filter every tail already known for (subject, relation).
            candidates = self._filtered_candidates(
                self.head_relation_triplets, subject, relation,
                all_triplets[:, 2], object_, num_entity)
        elif is_object:
            # Filter every head already known for (object, relation).
            candidates = self._filtered_candidates(
                self.tail_relation_triplets, object_, relation,
                all_triplets[:, 0], subject, num_entity)
        else:
            raise ValueError(
                "Test triplet does not contain task entity {}".format(
                    task_entity))

        score = self._score_candidates(score_function, task_entity_embedding,
                                       relation_embeddings[relation],
                                       entity_embeddings[candidates],
                                       task_is_subject=is_subject)
        return score, (is_trans, is_subject, is_object)

    def experiment_name(self, args):
        """Return the checkpoint directory ``./checkpoints/<args.exp_name>``."""
        return os.path.join('./checkpoints', self.args.exp_name)
class Trainer(object):
    """Meta-train and evaluate a TransGEN model with ranking metrics.

    For every meta-task entity, ``self.model`` builds an embedding from the
    task's first few ("support") triplets; the remaining triplets are ranked
    by ``utils.cal_trans_mrr`` and summarised as MRR and Hits@{1, 3, 10}.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.exp_name = self.experiment_name(args)
        self.best_mrr = 0
        self.use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(args.gpu)

        (self.entity2id, self.relation2id, self.train_triplets,
         self.valid_triplets, self.test_triplets) = utils.load_data(
             './Dataset/raw_data/{}'.format(args.data))

        (self.filtered_triplets, self.meta_train_task_triplets,
         self.meta_valid_task_triplets, self.meta_test_task_triplets,
         self.meta_train_task_entity_to_triplets,
         self.meta_valid_task_entity_to_triplets,
         self.meta_test_task_entity_to_triplets) = utils.load_processed_data(
             './Dataset/processed_data/{}'.format(args.data))

        self.all_triplets = torch.LongTensor(
            np.concatenate((self.train_triplets, self.valid_triplets,
                            self.test_triplets)))

        # Meta-valid and meta-test task entities: excluded from the
        # negative-sampling pool below and zeroed in ``load_model``.
        self.meta_task_entity = np.concatenate(
            (list(self.meta_valid_task_entity_to_triplets.keys()),
             list(self.meta_test_task_entity_to_triplets.keys())))
        self.entities_list = np.delete(np.arange(len(self.entity2id)),
                                       self.meta_task_entity)

        self.load_pretrain_embedding()
        self.load_model()

        if self.use_cuda:
            self.model.cuda()
            self.all_triplets = self.all_triplets.cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.lr,
                                          weight_decay=self.args.weight_decay)

    def load_pretrain_embedding(self):
        """Load pre-trained entity/relation embeddings when ``args.pre_train``.

        Files are read from
        ``./Pretraining/<data>/<pre_train_model>_{entity,relation}_<size>.npy``;
        without pre-training both embeddings are ``None``.
        """
        self.embedding_size = int(self.args.pre_train_emb_size)
        if self.args.pre_train:
            pretrain_model_path = './Pretraining/{}'.format(self.args.data)
            entity_file_name = os.path.join(
                pretrain_model_path,
                '{}_entity_{}.npy'.format(self.args.pre_train_model,
                                          self.embedding_size))
            relation_file_name = os.path.join(
                pretrain_model_path,
                '{}_relation_{}.npy'.format(self.args.pre_train_model,
                                            self.embedding_size))
            self.pretrain_entity_embedding = torch.Tensor(
                np.load(entity_file_name))
            self.pretrain_relation_embedding = torch.Tensor(
                np.load(relation_file_name))
        else:
            self.pretrain_entity_embedding = None
            self.pretrain_relation_embedding = None

    def load_model(self):
        """Instantiate ``self.model`` and zero the meta-task entity rows.

        Raises:
            ValueError: if ``args.model`` is not ``'TransGEN'``.
        """
        if self.args.model == 'TransGEN':
            self.model = TransGEN(
                self.embedding_size, self.embedding_size,
                len(self.entity2id), len(self.relation2id),
                args=self.args,
                entity_embedding=self.pretrain_entity_embedding,
                relation_embedding=self.pretrain_relation_embedding)
        else:
            raise ValueError("Model Name <{}> is Wrong".format(
                self.args.model))

        # Zero the embedding rows of meta-task entities (presumably so any
        # pre-trained values for these unseen entities are not used).
        meta_task_entity = torch.LongTensor(self.meta_task_entity)
        self.model.entity_embedding.weight.data[
            meta_task_entity] = torch.zeros(len(meta_task_entity),
                                            self.embedding_size)

    # Meta-Learning for Long-Tail Tasks
    def cal_train_few(self, epoch):
        """Return the number of support triplets to use at ``epoch``.

        When ``args.model_tail == 'log'``: find the smallest ``i`` in
        ``range(args.max_few)`` with ``epoch >= n_epochs / 2**i`` and return
        ``few + i - 1`` clamped by ``max(min(max_few, .), few)``; return
        ``max_few`` if no such ``i`` exists. Otherwise return ``args.few``.
        """
        if self.args.model_tail == 'log':
            for i in range(self.args.max_few):
                if epoch < (self.args.n_epochs / (2**i)):
                    continue
                return max(min(self.args.max_few, self.args.few + i - 1),
                           self.args.few)
            return self.args.max_few
        else:
            return self.args.few

    def train(self):
        """Run meta-training with periodic validation, then meta-test.

        Whenever validation MRR improves, the model state is saved to
        ``./checkpoints/<exp_name>/best_mrr_model.pth``; that checkpoint is
        reloaded before the final meta-test evaluation.
        """
        checkpoint_path = './checkpoints/{}/best_mrr_model.pth'.format(
            self.exp_name)

        for epoch in trange(0, (self.args.n_epochs + 1),
                            desc='Train Epochs',
                            position=0):
            # Meta-Train
            self.model.train()

            train_task_pool = list(
                self.meta_train_task_entity_to_triplets.keys())
            random.shuffle(train_task_pool)

            total_unseen_entity = []
            total_unseen_entity_embedding = []
            total_train_triplets = []
            total_pos_samples = []
            total_neg_samples = []

            train_few = self.cal_train_few(epoch)

            for unseen_entity in train_task_pool[:self.args.num_train_entity]:
                triplets = self.meta_train_task_entity_to_triplets[
                    unseen_entity]
                random.shuffle(triplets)
                triplets = np.array(triplets)
                heads, _, tails = triplets.transpose()

                train_triplets = triplets[:train_few]
                test_triplets = triplets[train_few:]

                if (len(triplets)) - train_few < 1:
                    continue

                # Negative candidates: entities in the sampling pool that do
                # not appear as head or tail in any of this task's triplets.
                false_candidates = np.array(
                    list(
                        set(self.entities_list) -
                        set(np.concatenate((heads, tails)))))
                false_entities = np.random.choice(
                    false_candidates,
                    size=(len(triplets) - train_few) *
                    self.args.negative_sample)

                pos_samples = test_triplets
                neg_samples = np.tile(pos_samples,
                                      (self.args.negative_sample, 1))
                # Corrupt the side opposite the unseen entity. The tail-side
                # mask is evaluated after the head-side replacement.
                head_mask = neg_samples[:, 0] == unseen_entity
                neg_samples[head_mask, 2] = false_entities[head_mask]
                tail_mask = neg_samples[:, 2] == unseen_entity
                neg_samples[tail_mask, 0] = false_entities[tail_mask]

                total_pos_samples.extend(pos_samples)
                total_neg_samples.extend(neg_samples)

                # Train (Inductive)
                unseen_entity_embedding = self.model(unseen_entity,
                                                     train_triplets,
                                                     self.use_cuda,
                                                     is_trans=False)
                total_unseen_entity.append(unseen_entity)
                total_unseen_entity_embedding.append(unseen_entity_embedding)
                total_train_triplets.extend(train_triplets)

            total_unseen_entity = np.array(total_unseen_entity)
            total_unseen_entity_embedding = torch.cat(
                total_unseen_entity_embedding).view(-1, self.embedding_size)
            total_train_triplets = np.array(total_train_triplets)
            total_pos_samples = torch.LongTensor(np.array(total_pos_samples))
            total_neg_samples = torch.LongTensor(np.array(total_neg_samples))

            total_samples = torch.cat((total_pos_samples, total_neg_samples))

            if self.use_cuda:
                total_samples = total_samples.cuda()

            # Train (Transductive)
            unseen_entity_embeddings, _, _ = self.model(
                total_unseen_entity,
                total_train_triplets,
                self.use_cuda,
                is_trans=True,
                total_unseen_entity_embedding=total_unseen_entity_embedding)

            loss = self.model.score_loss(total_unseen_entity,
                                         unseen_entity_embeddings,
                                         total_samples,
                                         target=None,
                                         use_cuda=self.use_cuda)

            # Test Update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.args.grad_norm)
            self.optimizer.step()

            # Meta-Valid
            if epoch % self.args.evaluate_every == 0:
                tqdm.write("Epochs-{}, Loss-{}".format(epoch, loss))
                with torch.no_grad():
                    results = self.eval(eval_type='valid')
                    mrr = results['total_mrr']
                    self._print_results(results)

                    if mrr > self.best_mrr:
                        self.best_mrr = mrr
                        torch.save(
                            {
                                'state_dict': self.model.state_dict(),
                                'epoch': epoch
                            }, checkpoint_path)

        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['state_dict'])

        print("Using best epoch: {}, {}".format(checkpoint['epoch'],
                                                self.exp_name))

        # Meta-Test
        results = self.eval(eval_type='test')
        self._print_results(results)

    @staticmethod
    def _print_results(results):
        """Write total MRR and Hits@{1, 3, 10} from ``results`` via tqdm."""
        tqdm.write("Total MRR (filtered): {:.6f}".format(results['total_mrr']))
        for hit in [1, 3, 10]:
            tqdm.write("Total Hits (filtered) @ {}: {:.6f}".format(
                hit, results['total_hits@{}'.format(hit)]))

    @staticmethod
    def _summarize_ranks(rank_chunks, prefix, results):
        """Store ranks, MRR and Hits@{1, 3, 10} for ``prefix`` in ``results``.

        ``rank_chunks`` is a (possibly empty) list of rank tensors. Ranks are
        shifted by +1 before computing metrics, which suggests the scorer
        returns zero-based positions (TODO confirm in ``utils``). An empty
        list yields an empty rank tensor and NaN metrics instead of raising.
        """
        if rank_chunks:
            ranks = torch.cat(rank_chunks)
        else:
            ranks = torch.zeros(0, dtype=torch.long)
        ranks = ranks + 1
        results['{}_ranks'.format(prefix)] = ranks
        results['{}_mrr'.format(prefix)] = torch.mean(
            1.0 / ranks.float()).item()
        for hit in [1, 3, 10]:
            avg_count = torch.mean((ranks <= hit).float())
            results['{}_hits@{}'.format(prefix, hit)] = avg_count.item()

    @torch.no_grad()
    def eval(self, eval_type='test'):
        """Rank the held-out triplets of meta-valid or meta-test tasks.

        For each task with more than ``args.few`` triplets, the first
        ``args.few`` triplets generate the entity embedding and the rest are
        ranked by ``utils.cal_trans_mrr``. Runs without autograd, since
        evaluation never back-propagates.

        Returns:
            dict with ``{subject,object,total}_ranks``, ``..._mrr`` and
            ``..._hits@{1,3,10}`` entries. If there are no subject (or no
            object) ranks, the corresponding metrics are NaN.

        Raises:
            ValueError: if ``eval_type`` is neither ``'valid'`` nor ``'test'``.
        """
        self.model.eval()

        if eval_type == 'valid':
            test_task_dict = self.meta_valid_task_entity_to_triplets
        elif eval_type == 'test':
            test_task_dict = self.meta_test_task_entity_to_triplets
        else:
            raise ValueError("Eval Type <{}> is Wrong".format(eval_type))

        total_unseen_entity = []
        total_unseen_entity_embedding = []
        total_train_triplets = []
        total_test_triplets_dict = dict()

        for unseen_entity in test_task_dict:
            triplets = np.array(test_task_dict[unseen_entity])

            train_triplets = triplets[:self.args.few]
            test_triplets = triplets[self.args.few:]

            if (len(triplets)) - self.args.few < 1:
                continue

            # Train (Inductive)
            unseen_entity_embedding = self.model(unseen_entity,
                                                 train_triplets,
                                                 use_cuda=self.use_cuda,
                                                 is_trans=False)
            total_unseen_entity.append(unseen_entity)
            total_unseen_entity_embedding.append(unseen_entity_embedding)
            total_train_triplets.extend(train_triplets)
            total_test_triplets_dict[unseen_entity] = torch.LongTensor(
                test_triplets)

        # Train (Transductive)
        total_unseen_entity = np.array(total_unseen_entity)
        total_unseen_entity_embedding = torch.cat(
            total_unseen_entity_embedding).view(-1, self.embedding_size)
        total_train_triplets = np.array(total_train_triplets)

        unseen_entity_embeddings, _, _ = self.model(
            total_unseen_entity,
            total_train_triplets,
            use_cuda=self.use_cuda,
            is_trans=True,
            total_unseen_entity_embedding=total_unseen_entity_embedding)

        # Test
        total_unseen_entity = torch.from_numpy(total_unseen_entity)
        all_entity_embeddings = copy.deepcopy(
            self.model.entity_embedding.weight).detach()
        all_relation_embeddings = copy.deepcopy(
            self.model.relation_embedding).detach()

        ranks, ranks_s, ranks_o = utils.cal_trans_mrr(
            total_unseen_entity,
            unseen_entity_embeddings,
            all_entity_embeddings,
            all_relation_embeddings,
            total_test_triplets_dict,
            self.all_triplets,
            use_cuda=self.use_cuda,
            score_function=self.args.score_function)

        results = {}
        self._summarize_ranks([ranks_s] if len(ranks_s) != 0 else [],
                              'subject', results)
        self._summarize_ranks([ranks_o] if len(ranks_o) != 0 else [],
                              'object', results)
        self._summarize_ranks([ranks], 'total', results)
        return results

    def experiment_name(self, args):
        """Build the experiment name from hyper-parameters and UTC time.

        Unless ``args.debug`` is set, also creates ``./checkpoints/<name>``.
        """
        ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
        fields = [
            ("Data", args.data),
            ("Model", args.model),
            ("Bases", args.bases),
            ("DO", args.dropout),
            ("NS", args.negative_sample),
            ("Margin", args.margin),
            ("Few", args.few),
            ("LR", args.lr),
            ("WD", args.weight_decay),
            ("GN", args.grad_norm),
            ("PT", args.pre_train),
            ("PTM", args.pre_train_model),
            ("PTES", args.pre_train_emb_size),
            ("NE", args.num_train_entity),
            ("FT", args.fine_tune),
            ("SF", args.score_function),
        ]
        exp_name = "".join("{}={}_".format(key, value) for key, value in fields)
        exp_name += "TS={}".format(ts)

        if not args.debug:
            checkpoint_dir = './checkpoints/{}'.format(exp_name)
            if not os.path.isdir(checkpoint_dir):
                os.makedirs(checkpoint_dir, exist_ok=True)
                print("Make Directory {} in a Checkpoints Folder".format(
                    exp_name))

        return exp_name
class Trainer(object): def __init__(self, args): super(Trainer, self).__init__() self.args = args self.exp_name = self.experiment_name(args) self.best_roc = 0 self.use_cuda = args.gpu >= 0 and torch.cuda.is_available() if self.use_cuda: torch.cuda.set_device(args.gpu) self.filtered_triplets, self.meta_train_task_triplets, self.meta_valid_task_triplets, self.meta_test_task_triplets, \ self.meta_train_task_entity_to_triplets, self.meta_valid_task_entity_to_triplets, self.meta_test_task_entity_to_triplets \ = utils.load_processed_data('./Dataset/processed_data/{}'.format(args.data)) self.meta_task_entity = np.concatenate( (list(self.meta_train_task_entity_to_triplets.keys()), list(self.meta_valid_task_entity_to_triplets.keys()), list(self.meta_test_task_entity_to_triplets.keys()))) self.load_pretrain_embedding() self.load_model() print(self.model) if self.use_cuda: self.model.cuda() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) def load_pretrain_embedding(self): self.embedding_size = int(self.args.pre_train_emb_size) if self.args.pre_train: entity_file_name = './Pretraining/{}/{}_entity.npy'.format( self.args.data, self.args.pre_train_model) self.pretrain_entity_embedding = torch.Tensor( np.load(entity_file_name)) self.pretrain_relation_embedding = None else: self.pretrain_entity_embedding = None self.pretrain_relation_embedding = None def load_model(self): num_entities = None num_relations = None if self.args.data == 'BIOSNAP-sub': num_entities = 637 num_relations = 200 elif self.args.data == 'DeepDDI': num_entities = 1861 num_relations = 113 if self.args.model == 'TransGEN': self.model = TransGEN( self.embedding_size, self.embedding_size, num_entities, num_relations, args=self.args, entity_embedding=self.pretrain_entity_embedding, relation_embedding=self.pretrain_relation_embedding) else: raise ValueError("Model Name <{}> is Wrong".format( self.args.model)) meta_task_entity = 
torch.LongTensor(self.meta_task_entity) self.model.entity_embedding.weight.data[ meta_task_entity] = torch.zeros(len(meta_task_entity), self.embedding_size) def train(self): for epoch in trange(0, (self.args.n_epochs + 1), desc='Train Epochs', position=0): # Meta-Train self.model.train() train_task_pool = list( self.meta_train_task_entity_to_triplets.keys()) random.shuffle(train_task_pool) total_unseen_entity = [] total_unseen_entity_embedding = [] total_train_triplets = [] total_samples = [] for unseen_entity in train_task_pool[:self.args.num_train_entity]: triplets = self.meta_train_task_entity_to_triplets[ unseen_entity] random.shuffle(triplets) triplets = np.array(triplets) heads, relations, tails = triplets.transpose() train_triplets = triplets[:self.args.few] test_triplets = triplets[self.args.few:] if (len(triplets)) - self.args.few < 1: continue samples = test_triplets samples = torch.LongTensor(samples) if self.use_cuda: samples = samples.cuda() # Train (Inductive) unseen_entity_embedding = self.model(unseen_entity, train_triplets, self.use_cuda, is_trans=False) total_unseen_entity.append(unseen_entity) total_unseen_entity_embedding.append(unseen_entity_embedding) total_train_triplets.extend(train_triplets) total_samples.append(samples) total_unseen_entity = np.array(total_unseen_entity) total_unseen_entity_embedding = torch.cat( total_unseen_entity_embedding).view(-1, self.embedding_size) total_train_triplets = np.array(total_train_triplets) total_samples = torch.cat(total_samples, dim=0) # Train (Transductive) unseen_entity_embeddings, _, _ = self.model( total_unseen_entity, total_train_triplets, self.use_cuda, is_trans=True, total_unseen_entity_embedding=total_unseen_entity_embedding) loss = self.model.score_loss(total_unseen_entity, unseen_entity_embeddings, total_samples, target=None, use_cuda=self.use_cuda) # Test self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) 
self.optimizer.step() # Meta-Valid if epoch % self.args.evaluate_every == 0: tqdm.write("Epochs-{}, Loss-{}".format(epoch, loss)) with torch.no_grad(): results = self.eval(eval_type='valid') roc = results['roc'] tqdm.write("Total PR (filtered): {:.6f}".format( results['pr'])) tqdm.write("Total ROC (filtered): {:.6f}".format( results['roc'])) tqdm.write("Total ACC (filtered): {:.6f}".format( results['acc'])) if roc > self.best_roc: self.best_roc = roc torch.save( { 'state_dict': self.model.state_dict(), 'epoch': epoch }, './checkpoints/{}/best_mrr_model.pth'.format( self.exp_name)) checkpoint = torch.load('./checkpoints/{}/best_mrr_model.pth'.format( self.exp_name)) self.model.load_state_dict(checkpoint['state_dict']) print("Using best epoch: {}, {}".format(checkpoint['epoch'], self.exp_name)) # Meta-Test results = self.eval(eval_type='test') tqdm.write("Total PR (filtered): {:.6f}".format(results['pr'])) tqdm.write("Total ROC (filtered): {:.6f}".format(results['roc'])) tqdm.write("Total ACC (filtered): {:.6f}".format(results['acc'])) def eval(self, eval_type='test'): self.model.eval() if eval_type == 'valid': test_task_dict = self.meta_valid_task_entity_to_triplets test_task_pool = list( self.meta_valid_task_entity_to_triplets.keys()) elif eval_type == 'test': test_task_dict = self.meta_test_task_entity_to_triplets test_task_pool = list( self.meta_test_task_entity_to_triplets.keys()) else: raise ValueError("Eval Type <{}> is Wrong".format(eval_type)) total_unseen_entity = [] total_unseen_entity_embedding = [] total_train_triplets = [] total_test_triplets = [] total_test_triplets_dict = dict() for unseen_entity in test_task_pool: triplets = test_task_dict[unseen_entity] triplets = np.array(triplets) heads, relations, tails = triplets.transpose() train_triplets = triplets[:self.args.few] test_triplets = triplets[self.args.few:] if (len(triplets)) - self.args.few < 1: continue # Train (Inductive) unseen_entity_embedding = self.model(unseen_entity, train_triplets, 
use_cuda=self.use_cuda, is_trans=False) total_unseen_entity.append(unseen_entity) total_unseen_entity_embedding.append(unseen_entity_embedding) total_train_triplets.extend(train_triplets) total_test_triplets.extend(test_triplets) total_test_triplets_dict[unseen_entity] = torch.LongTensor( test_triplets) # Train (Transductive) total_unseen_entity = np.array(total_unseen_entity) total_unseen_entity_embedding = torch.cat( total_unseen_entity_embedding).view(-1, self.embedding_size) total_train_triplets = np.array(total_train_triplets) samples = total_test_triplets samples = torch.LongTensor(samples) if self.use_cuda: samples = samples.cuda() unseen_entity_embeddings, _, _ = self.model( total_unseen_entity, total_train_triplets, use_cuda=self.use_cuda, is_trans=True, total_unseen_entity_embedding=total_unseen_entity_embedding) y_prob, y = self.model.predict(total_unseen_entity, unseen_entity_embeddings, samples, target=None, use_cuda=self.use_cuda) y_prob = y_prob.detach().cpu().numpy() y = y.detach().cpu().numpy() results = utils.metric_report(y, y_prob) return results def experiment_name(self, args): ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime()) exp_name = str() exp_name += "Data={}_".format(args.data) exp_name += "Model={}_".format(args.model) exp_name += "Bases={}_".format(args.bases) exp_name += "DO={}_".format(args.dropout) exp_name += "Few={}_".format(args.few) exp_name += "LR={}_".format(args.lr) exp_name += "WD={}_".format(args.weight_decay) exp_name += "GN={}_".format(args.grad_norm) exp_name += "PT={}_".format(args.pre_train) exp_name += "PTM={}_".format(args.pre_train_model) exp_name += "PTES={}_".format(args.pre_train_emb_size) exp_name += "FT={}_".format(args.fine_tune) exp_name += "TS={}".format(ts) if not args.debug: if not (os.path.isdir('./checkpoints/{}'.format(exp_name))): os.makedirs(os.path.join('./checkpoints/{}'.format(exp_name))) print("Make Directory {} in a Checkpoints Folder".format(exp_name)) return exp_name