def train(epoch):
    """Run one pass over ``train_generator`` and return the summed batch loss.

    For every batch a subgraph is built around the ``pos`` triples, both the
    ``pos`` and ``neg`` batches are scored by ``model`` on that subgraph, and
    a margin ranking loss between the two scores is back-propagated.

    Relies on module-level state: ``model``, ``optimizer``, ``device``,
    ``train_generator``, ``train_triple_dict``, ``graph``, ``args`` and the
    ranking target ``y``.

    Args:
        epoch: Epoch index, used only for progress output.

    Returns:
        The sum of ``loss.item()`` over all batches of this epoch
        (not the average that is printed).
    """
    print("Epoch", epoch)
    t = time.time()
    model.train(True)
    torch.set_grad_enabled(True)
    eloss = 0
    num_batches = 0
    for batch_idx, instance in enumerate(train_generator):
        num_batches = batch_idx + 1
        pos, neg, pht_bef, ptt_bef, nht_bef, ntt_bef = instance
        pos = pos.to(device)
        neg = neg.to(device)
        # text information (only the first three entries of each are used)
        pht = [x.to(device) for x in pht_bef[0:3]]
        ptt = [x.to(device) for x in ptt_bef[0:3]]
        nht = [x.to(device) for x in nht_bef[0:3]]
        ntt = [x.to(device) for x in ntt_bef[0:3]]

        # The subgraph is built from the positive batch only; the negative
        # batch is scored against the same nodes/adjacency.
        batch_nodes, batch_adj = get_subgraph(pos, train_triple_dict, graph)
        # get relative location according to the batch_nodes
        shifted_pos, shifted_neg = convert_index([pos, neg], batch_nodes)
        batch_nodes = torch.LongTensor(batch_nodes.tolist()).to(device)
        batch_adj = torch.from_numpy(batch_adj).to(device)
        shifted_pos = torch.LongTensor(shifted_pos).to(device)
        shifted_neg = torch.LongTensor(shifted_neg).to(device)

        # PyTorch accumulates gradients across backward() calls, so they must
        # be cleared before each batch; otherwise every step would apply the
        # sum of all gradients seen so far.
        optimizer.zero_grad()

        score_pos = model(batch_nodes, batch_adj, pos, shifted_pos,
                          pht[0], pht[1], pht[2], ptt[0], ptt[1], ptt[2])
        score_neg = model(batch_nodes, batch_adj, neg, shifted_neg,
                          nht[0], nht[1], nht[2], ntt[0], ntt[1], ntt[2])
        # NOTE(review): `y` is a module-level target tensor; presumably it
        # matches the batch size -- confirm for a short final batch.
        loss_train = F.margin_ranking_loss(score_pos, score_neg, y,
                                           margin=args.margin)
        sys.stdout.write(
            '%d batches processed. current train batch loss: %f\r' %
            (batch_idx, loss_train.item()))
        eloss += loss_train.item()
        loss_train.backward()
        # Drop references to per-batch tensors before the optimizer step to
        # keep peak memory down.
        del batch_nodes, batch_adj, shifted_pos, shifted_neg, pos, neg, \
            pht_bef, ptt_bef, nht_bef, ntt_bef
        optimizer.step()
        if batch_idx % 500 == 0:
            gc.collect()
    print('\n')
    # max(..., 1) keeps the summary line from failing when the generator
    # yielded no batches at all.
    print('Epoch: {:04d}'.format(epoch + 1),
          'loss_train: {:.4f}'.format(eloss / max(num_batches, 1)),
          'time: {:.4f}s'.format(time.time() - t))
    return eloss
def _score_triples(triples, first_text, second_text):
    """Score one batch of triples with ``model`` on the subgraph around them.

    Args:
        triples: Batch of triples, already moved to ``device``.
        first_text: First text-feature sequence yielded for this batch; only
            its first three entries are used.
        second_text: Second text-feature sequence, used the same way.

    Returns:
        The detached model output for the batch.
    """
    first = [x.to(device) for x in first_text[0:3]]
    second = [x.to(device) for x in second_text[0:3]]
    batch_nodes, batch_adj = get_subgraph(triples, train_triple_dict, graph)
    # get relative location according to the batch_nodes
    shifted = convert_index([triples], batch_nodes)
    batch_nodes = torch.LongTensor(batch_nodes.tolist()).to(device)
    batch_adj = torch.from_numpy(batch_adj).to(device)
    shifted = torch.LongTensor(shifted[0]).to(device)
    score = model(batch_nodes, batch_adj, triples, shifted,
                  first[0], first[1], first[2],
                  second[0], second[1], second[2])
    return score.detach()


def test():
    """Evaluate ``model`` on ``test_set`` and report rank statistics.

    For each ``(new_head_triple, new_tail_triple, correct)`` entry, all
    candidate triples are scored, sorted by ascending score, and the first
    ten of each list are written (by name) to ``test_top10.txt``. The rank
    returned by ``mean_rank`` is recorded for the head and tail side, and a
    rank of at most 10 counts as a hit.

    Relies on module-level state: ``model``, ``device``, ``test_set``,
    ``params_test``, ``text_f``, ``id2ent``, ``ent2name``, ``id2rel``,
    ``num_ent``, ``train_triple_dict`` and ``graph``.

    Returns:
        A tuple ``(hitk_all, mean_rank_head, mean_rank_tail)``: the total
        number of hits over both sides, and the per-entry rank lists.
    """
    print('Testing...')
    # model.cpu()
    model.eval()
    hitk_all = 0
    mean_rank_head = []
    mean_rank_tail = []
    # The context manager closes the file even if scoring raises; no_grad
    # avoids building autograd graphs during inference (grad mode is switched
    # on globally by train()).
    with open('test_top10.txt', 'w') as output_file, torch.no_grad():
        for batch_idx, (new_head_triple, new_tail_triple, correct) in enumerate(test_set):
            print("Current: ", batch_idx, "Total: ", len(test_set))
            test_dataset = LinkTestDataset(new_head_triple, new_tail_triple,
                                           text_f, id2ent)
            test_generator = data.DataLoader(test_dataset, **params_test)
            scores_heads = []
            scores_tails = []
            for current_idx, instance in enumerate(test_generator):
                head, tail, hht_bef, htt_bef, tht_bef, ttt_bef = instance
                scores_heads.append(
                    _score_triples(head.to(device), hht_bef, htt_bef))
                scores_tails.append(
                    _score_triples(tail.to(device), tht_bef, ttt_bef))
                sys.stdout.write('%d batches processed.\r' % (current_idx))

            # get head scores
            scores_head = torch.cat(scores_heads, 0)
            scores_head = torch.sum(scores_head, 1).squeeze()
            assert scores_head.size(0) == num_ent
            # argsort is ascending, so the lowest-scored candidates come first
            sorted_head_idx = np.argsort(scores_head.tolist())
            topk_head = new_head_triple[sorted_head_idx][:10]

            # get tail scores
            scores_tail = torch.cat(scores_tails, 0)
            scores_tail = torch.sum(scores_tail, 1).squeeze()
            sorted_tail_idx = np.argsort(scores_tail.tolist())
            topk_tail = new_tail_triple[sorted_tail_idx][:10]

            # predict and output top 10 triples
            named_triples_head = convert_idx2name(topk_head, id2ent, ent2name, id2rel)
            named_triples_tail = convert_idx2name(topk_tail, id2ent, ent2name, id2rel)
            write_triples(named_triples_head, output_file)
            write_triples(named_triples_tail, output_file)

            mean_rank_result_head = mean_rank(new_head_triple, sorted_head_idx, correct, 0)
            mean_rank_result_tail = mean_rank(new_tail_triple, sorted_tail_idx, correct, 1)
            if mean_rank_result_head <= 10:
                hitk_all += 1
            if mean_rank_result_tail <= 10:
                hitk_all += 1
            mean_rank_head.append(mean_rank_result_head)
            mean_rank_tail.append(mean_rank_result_tail)

            del test_dataset
            gc.collect()

    print('Final mean rank for head is %f' % (np.mean(mean_rank_head)))
    print('Final median rank for head is %f' % np.median(mean_rank_head))
    print('Final mean rank for tail is %f' % (np.mean(mean_rank_tail)))
    print('Final median rank for tail is %f' % np.median(mean_rank_tail))
    # Each test entry contributes one head and one tail prediction, so the
    # hit rate is taken over 2 * N predictions.
    total_predictions = 2 * len(mean_rank_tail)
    hit10 = hitk_all / total_predictions if total_predictions else 0.0
    print('Final hit10 is %f' % hit10)
    return hitk_all, mean_rank_head, mean_rank_tail