def test(args):
    """Evaluate a saved DArtNet checkpoint on the test split and record metrics.

    Loads ``models/<dataset>/<dropout>-<n_hidden>-<gamma>-<num_k>/epoch-<epoch>.pth``,
    accumulates the subject-attribute loss from ``model.predict`` over batches
    of test rows, collects ranks from ``model.evaluate_filter`` for every test
    row, prints Hits@{1,3,10}, MRR, MR and the mean loss, and stores a
    ``result`` entry in the module-level ``result_dict`` under ``args.epoch``.

    Args:
        args: Parsed options. Reads ``dataset_path``, ``dataset``, ``dropout``,
            ``n_hidden``, ``gamma``, ``num_k``, ``epoch``, ``gpu``, ``model``,
            ``seq_len`` and ``batch_size``.
    """
    # Order matters: it is the positional order expected by model.predict /
    # model.evaluate_filter after the batch tensor.
    history_keys = ('entity_s', 'rel_s', 'att_s', 'self_att_s',
                    'entity_o', 'rel_o', 'att_o', 'self_att_o')

    # load data
    num_nodes, num_rels = utils.get_total_number(args.dataset_path, 'stat.txt')
    test_data, _ = utils.load_hexaruples(args.dataset_path, 'test.txt')
    total_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt',
                                          'valid.txt', 'test.txt')

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)
    model_state_file = model_dir + '/epoch-{}.pth'.format(args.epoch)

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)
    if use_cuda:
        model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only point dataset_path
    # at trusted, locally generated history files.
    history = {}
    for key in history_keys:
        path = args.dataset_path + '/test_{}_history_data.txt'.format(key)
        with open(path, 'rb') as f:
            history[key] = pickle.load(f)

    def history_args(index):
        """Return copies of all eight histories at ``index`` (int or slice)."""
        return [history[key][index].copy() for key in history_keys]

    print(f'\nstart testing model file : {model_state_file}')
    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.init_history()
    model.latest_time = checkpoint['latest_time']
    print("Using epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)

    model.eval()
    total_att_sub_loss = 0
    # Seed with an empty float array so the result keeps the float dtype and
    # an empty test set still concatenates.
    rank_chunks = [np.array([])]

    with torch.no_grad():
        j = 0
        while j < len(test_data):
            # [j, k) is the run of consecutive rows sharing the same value in
            # the last column (presumably the timestamp -- confirm against
            # utils.load_hexaruples).
            k = j
            while k < len(test_data) and test_data[k][-1] == test_data[j][-1]:
                k += 1

            # Attribute loss: batched prediction over the whole run first.
            start = j
            while start < k:
                end = min(k, start + args.batch_size)
                batch_data = test_data[start:end].clone()
                if use_cuda:
                    batch_data = batch_data.cuda()
                loss_sub = model.predict(batch_data,
                                         *history_args(slice(start, end)))
                # Weight by the number of rows actually in the batch so that
                # dividing by len(test_data) below gives a per-row mean.
                total_att_sub_loss += loss_sub.item() * (end - start)
                start = end

            # Ranking: one row at a time, after the run has been predicted.
            for i in range(j, k):
                batch_data = test_data[i].clone()
                if use_cuda:
                    batch_data = batch_data.cuda()
                ranks_pred = model.evaluate_filter(batch_data,
                                                   *history_args(i),
                                                   total_data)
                rank_chunks.append(ranks_pred)

            j = k

    total_ranks = np.concatenate(rank_chunks)

    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)
    hits = []
    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        hits.append(avg_count)
        print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
    print("MRR (filtered): {:.6f}".format(mrr))
    print("MR (filtered): {:.6f}".format(mr))
    print("test att sub Loss: {:.6f}".format(total_att_sub_loss /
                                             (len(test_data))))

    # `result` and `result_dict` are module-level names defined outside this
    # function.
    result_epoch = result(epoch=args.epoch,
                          MRR=100 * mrr,
                          sub_att_loss=total_att_sub_loss / len(test_data),
                          MR=mr,
                          Hits1=100 * hits[0],
                          Hits3=100 * hits[1],
                          Hits10=100 * hits[2])
    result_dict[args.epoch] = result_epoch
def train(args):
    """Train RENet and checkpoint the model with the best raw validation MRR.

    Each epoch shuffles the training quadruples together with their subject
    and object histories, optimises ``model.get_loss`` with Adam and gradient
    clipping, then evaluates every validation row with ``model.evaluate`` and
    saves the model state plus its history caches whenever the validation MRR
    improves.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``model``,
            ``n_hidden``, ``dropout``, ``seq_len``, ``lr``, ``batch_size``,
            ``grad_norm`` and ``max_epochs``.

    Raises:
        ValueError: If ``args.model`` is not 0, 1 or 2 (no checkpoint file
            name is defined for other values).
    """
    state_file_suffixes = {0: 'attn.pth', 1: 'mean.pth', 2: 'gcn.pth'}
    if args.model not in state_file_suffixes:
        # Fail before training instead of at the first checkpoint save.
        raise ValueError(
            'unsupported args.model {!r}; expected one of {}'.format(
                args.model, sorted(state_file_suffixes)))

    data_dir = './data/' + args.dataset

    # load data
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    train_data, _ = utils.load_quadruples(data_dir, 'train.txt')
    valid_data, _ = utils.load_quadruples(data_dir, 'valid.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    os.makedirs('models', exist_ok=True)
    # The file name is "<dataset><suffix>" directly under models/ (no
    # sub-directory); the matching test script reads the same path.
    model_state_file = 'models/' + args.dataset + state_file_suffixes[args.model]

    print("start training...")
    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  dropout=args.dropout,
                  model=args.model,
                  seq_len=args.seq_len)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)
    if use_cuda:
        model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only use trusted,
    # locally generated history files.
    with open(data_dir + '/train_history_sub.txt', 'rb') as f:
        s_history = pickle.load(f)
    with open(data_dir + '/train_history_ob.txt', 'rb') as f:
        o_history = pickle.load(f)
    with open(data_dir + '/dev_history_sub.txt', 'rb') as f:
        s_history_valid = pickle.load(f)
    with open(data_dir + '/dev_history_ob.txt', 'rb') as f:
        o_history_valid = pickle.load(f)

    valid_data = torch.from_numpy(valid_data)

    best_mrr = 0
    for epoch in range(1, args.max_epochs + 1):
        model.train()
        loss_epoch = 0
        t0 = time.time()

        # Rows and their histories are shuffled together so they stay aligned.
        train_data, s_history, o_history = shuffle(train_data, s_history,
                                                   o_history)

        for batch_data, s_hist, o_hist in utils.make_batch(
                train_data, s_history, o_history, args.batch_size):
            batch_data = torch.from_numpy(batch_data)
            if use_cuda:
                batch_data = batch_data.cuda()

            loss = model.get_loss(batch_data, s_hist, o_hist)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_data) / args.batch_size), t3 - t0))

        # Validation runs after every epoch.
        model.eval()
        total_loss = 0
        rank_chunks = [np.array([])]
        model.init_history()
        model.latest_time = valid_data[0][3]

        with torch.no_grad():
            for i in range(len(valid_data)):
                batch_data = valid_data[i]
                if use_cuda:
                    batch_data = batch_data.cuda()
                ranks, loss = model.evaluate(batch_data, s_history_valid[i],
                                             o_history_valid[i])
                rank_chunks.append(ranks)
                total_loss += loss.item()
        total_ranks = np.concatenate(rank_chunks)

        mrr = np.mean(1.0 / total_ranks)
        mr = np.mean(total_ranks)
        hits = []
        for hit in [1, 3, 10]:
            avg_count = np.mean((total_ranks <= hit))
            hits.append(avg_count)
            print("valid Hits (raw) @ {}: {:.6f}".format(hit, avg_count))
        print("valid MRR (raw): {:.6f}".format(mrr))
        print("valid MR (raw): {:.6f}".format(mr))
        print("valid Loss: {:.6f}".format(total_loss / (len(valid_data))))

        if mrr > best_mrr:
            best_mrr = mrr
            # History caches are stored with the weights so testing can
            # continue from the state reached at the end of validation.
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    's_hist': model.s_hist_test,
                    's_cache': model.s_his_cache,
                    'o_hist': model.o_hist_test,
                    'o_cache': model.o_his_cache,
                    'latest_time': model.latest_time
                }, model_state_file)

    print("training done")
def test(args):
    """Evaluate the saved RENet + RENet_global checkpoints on the test split.

    Restores model weights, history caches, global embedding and graph
    dictionary from ``models/<dataset>/``, ranks every test quadruple with
    ``model.evaluate`` (when ``args.raw`` is truthy) or
    ``model.evaluate_filter`` (otherwise), and prints Hits@{1,3,10}, MRR and
    MR labelled with the metric that was actually computed.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``maxpool``,
            ``n_hidden``, ``model``, ``seq_len``, ``num_k`` and ``raw``.
    """
    data_dir = './data/' + args.dataset
    model_dir = 'models/' + args.dataset
    is_icews_know = args.dataset == 'icews_know'

    # load data
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    test_data, _ = utils.load_quadruples(data_dir, 'test.txt')
    if is_icews_know:
        # The icews_know branch uses only train + test for filtering.
        total_data, _ = utils.load_quadruples(data_dir, 'train.txt',
                                              'test.txt')
    else:
        total_data, _ = utils.load_quadruples(data_dir, 'train.txt',
                                              'valid.txt', 'test.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    model_state_file = model_dir + '/rgcn.pth'
    model_graph_file = model_dir + '/rgcn_graph.pth'
    model_state_global_file2 = model_dir + '/max' + str(
        args.maxpool) + 'rgcn_global2.pth'

    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  model=args.model,
                  seq_len=args.seq_len,
                  num_k=args.num_k)
    global_model = RENet_global(num_nodes,
                                args.n_hidden,
                                num_rels,
                                model=args.model,
                                seq_len=args.seq_len,
                                num_k=args.num_k,
                                maxpool=args.maxpool)
    if use_cuda:
        model.cuda()
        global_model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only use trusted,
    # locally generated files.
    with open(data_dir + '/test_history_sub.txt', 'rb') as f:
        s_history_test_data = pickle.load(f)
    with open(data_dir + '/test_history_ob.txt', 'rb') as f:
        o_history_test_data = pickle.load(f)

    # Each pickle holds the histories at index 0 and the matching time
    # histories at index 1.
    s_history_test = s_history_test_data[0]
    s_history_test_t = s_history_test_data[1]
    o_history_test = o_history_test_data[0]
    o_history_test_t = o_history_test_data[1]

    print("\nstart testing:")

    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.s_hist_test = checkpoint['s_hist']
    model.s_his_cache = checkpoint['s_cache']
    model.o_hist_test = checkpoint['o_hist']
    model.o_his_cache = checkpoint['o_cache']
    model.latest_time = checkpoint['latest_time']
    if is_icews_know:
        # Hard-coded override for this dataset; presumably its last training
        # timestamp -- TODO confirm.
        model.latest_time = torch.LongTensor([4344])[0]
    model.global_emb = checkpoint['global_emb']
    model.s_hist_test_t = checkpoint['s_hist_t']
    model.s_his_cache_t = checkpoint['s_cache_t']
    model.o_hist_test_t = checkpoint['o_hist_t']
    model.o_his_cache_t = checkpoint['o_cache_t']
    with open(model_graph_file, 'rb') as f:
        model.graph_dict = pickle.load(f)

    checkpoint_global = torch.load(model_state_global_file2,
                                   map_location=lambda storage, loc: storage)
    global_model.load_state_dict(checkpoint_global['state_dict'])

    print("Using best epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)

    model.eval()
    global_model.eval()

    # Trim every restored per-entity history (and its time history, in
    # lockstep) down to the most recent seq_len entries.
    for ee in range(num_nodes):
        while len(model.s_hist_test[ee]) > args.seq_len:
            model.s_hist_test[ee].pop(0)
            model.s_hist_test_t[ee].pop(0)
        while len(model.o_hist_test[ee]) > args.seq_len:
            model.o_hist_test[ee].pop(0)
            model.o_hist_test_t[ee].pop(0)

    if use_cuda:
        total_data = total_data.cuda()

    # Seed with an empty float array so the result keeps the float dtype and
    # an empty test set still concatenates.
    rank_chunks = [np.array([])]
    with torch.no_grad():
        for i in range(len(test_data)):
            batch_data = test_data[i]
            if use_cuda:
                batch_data = batch_data.cuda()

            # The time histories are always passed to the model, so they are
            # read for every model type rather than only for args.model == 3.
            s_hist = (s_history_test[i], s_history_test_t[i])
            o_hist = (o_history_test[i], o_history_test_t[i])

            if args.raw:
                ranks_row, _ = model.evaluate(batch_data, s_hist, o_hist,
                                              global_model)
            else:
                ranks_row, _ = model.evaluate_filter(batch_data, s_hist,
                                                     o_hist, global_model,
                                                     total_data)
            rank_chunks.append(ranks_row)

    total_ranks = np.concatenate(rank_chunks)

    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)
    metric = 'raw' if args.raw else 'filtered'
    hits = []
    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        hits.append(avg_count)
        print("Hits ({}) @ {}: {:.6f}".format(metric, hit, avg_count))
    print("MRR ({}): {:.6f}".format(metric, mrr))
    print("MR ({}): {:.6f}".format(metric, mr))
def train(args):
    """Train RENet on top of a pre-trained RENet_global model.

    The global model is restored from
    ``models/<dataset>/max<maxpool>rgcn_global.pth`` and its stored global
    embedding is attached to the RENet model. Only RENet's parameters are
    handed to the optimiser. Starting from the second half of training, every
    ``args.valid_every`` epochs the model is scored with
    ``model.evaluate_filter`` on the validation rows; when the filtered MRR
    improves, both models and the graph dictionary are written to
    ``models/<dataset>/``.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``maxpool``,
            ``n_hidden``, ``dropout``, ``model``, ``seq_len``, ``num_k``,
            ``lr``, ``batch_size``, ``grad_norm``, ``max_epochs`` and
            ``valid_every``.
    """

    def read_pickle(path):
        # NOTE: pickle.load can execute arbitrary code; trusted files only.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    dataset_dir = './data/' + args.dataset
    uses_icews_know = args.dataset == 'icews_know'

    # load data -- icews_know has no separate validation split, so its test
    # file doubles as the validation file.
    num_nodes, num_rels = utils.get_total_number(dataset_dir, 'stat.txt')
    train_data, _ = utils.load_quadruples(dataset_dir, 'train.txt')
    valid_data, _ = utils.load_quadruples(
        dataset_dir, 'test.txt' if uses_icews_know else 'valid.txt')
    test_data, _ = utils.load_quadruples(dataset_dir, 'test.txt')
    if uses_icews_know:
        total_data, _ = utils.load_quadruples(dataset_dir, 'train.txt',
                                              'test.txt')
    else:
        total_data, _ = utils.load_quadruples(dataset_dir, 'train.txt',
                                              'valid.txt', 'test.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    model_dir = 'models/' + args.dataset
    os.makedirs('models', exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    pool_tag = '/max' + str(args.maxpool)
    model_state_file = model_dir + '/rgcn.pth'
    model_graph_file = model_dir + '/rgcn_graph.pth'
    model_state_global_file2 = model_dir + pool_tag + 'rgcn_global2.pth'
    model_state_global_file = model_dir + pool_tag + 'rgcn_global.pth'

    print("start training...")
    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  dropout=args.dropout,
                  model=args.model,
                  seq_len=args.seq_len,
                  num_k=args.num_k)
    global_model = RENet_global(num_nodes,
                                args.n_hidden,
                                num_rels,
                                dropout=args.dropout,
                                model=args.model,
                                seq_len=args.seq_len,
                                num_k=args.num_k,
                                maxpool=args.maxpool)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    # Restore the pre-trained global model and hand its embedding to RENet.
    checkpoint_global = torch.load(model_state_global_file,
                                   map_location=lambda storage, loc: storage)
    global_model.load_state_dict(checkpoint_global['state_dict'])
    model.global_emb = checkpoint_global['global_emb']

    if use_cuda:
        model.cuda()
        global_model.cuda()

    if uses_icews_know:
        valid_sub, valid_ob = '/test_history_sub.txt', '/test_history_ob.txt'
    else:
        valid_sub, valid_ob = '/dev_history_sub.txt', '/dev_history_ob.txt'

    graph_dict = read_pickle('./data/' + args.dataset + '/train_graphs.txt')
    model.graph_dict = graph_dict

    # Each history pickle holds the histories at index 0 and the matching
    # time histories at index 1.
    test_sub_pair = read_pickle('data/' + args.dataset + '/test_history_sub.txt')
    test_ob_pair = read_pickle('data/' + args.dataset + '/test_history_ob.txt')
    s_history_test, s_history_test_t = test_sub_pair[0], test_sub_pair[1]
    o_history_test, o_history_test_t = test_ob_pair[0], test_ob_pair[1]

    train_sub_pair = read_pickle('./data/' + args.dataset + '/train_history_sub.txt')
    train_ob_pair = read_pickle('./data/' + args.dataset + '/train_history_ob.txt')
    valid_sub_pair = read_pickle('./data/' + args.dataset + valid_sub)
    valid_ob_pair = read_pickle('./data/' + args.dataset + valid_ob)

    valid_data = torch.from_numpy(valid_data)

    s_history, s_history_t = train_sub_pair[0], train_sub_pair[1]
    o_history, o_history_t = train_ob_pair[0], train_ob_pair[1]
    s_history_valid, s_history_valid_t = valid_sub_pair[0], valid_sub_pair[1]
    o_history_valid, o_history_valid_t = valid_ob_pair[0], valid_ob_pair[1]

    total_data = torch.from_numpy(total_data)
    if use_cuda:
        total_data = total_data.cuda()

    epoch = 0
    best_mrr = 0
    while True:
        print('training starting')
        model.train()
        if epoch == args.max_epochs:
            break
        epoch += 1
        loss_epoch = 0
        t0 = time.time()
        print('training time captured')

        # Shuffle rows and all four histories together; the unshuffled
        # originals are kept for model.init_history below.
        shuffled = shuffle(train_data, s_history, s_history_t, o_history,
                           o_history_t)
        (train_data_shuffle, s_history_shuffle, s_history_t_shuffle,
         o_history_shuffle, o_history_t_shuffle) = shuffled
        print('training data formatted')

        batches = utils.make_batch2(train_data_shuffle, s_history_shuffle,
                                    s_history_t_shuffle, o_history_shuffle,
                                    o_history_t_shuffle, args.batch_size)
        for batch_data, s_hist, s_hist_t, o_hist, o_hist_t in batches:
            batch_data = torch.from_numpy(batch_data).long()
            if use_cuda:
                batch_data = batch_data.cuda()
            print('batch instance preprocessing')

            subject_history = (s_hist, s_hist_t)
            object_history = (o_hist, o_hist_t)
            loss_s = model(batch_data, subject_history, object_history,
                           graph_dict, subject=True)
            loss_o = model(batch_data, subject_history, object_history,
                           graph_dict, subject=False)
            loss = loss_s + loss_o

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_data) / args.batch_size), t3 - t0))

        # VALIDATION: only on scheduled epochs in the second half of training.
        if epoch % args.valid_every != 0 or epoch < int(args.max_epochs / 2):
            continue

        model.eval()
        global_model.eval()
        total_loss = 0
        total_ranks = np.array([])
        model.init_history(train_data, (s_history, s_history_t),
                           (o_history, o_history_t), valid_data,
                           (s_history_valid, s_history_valid_t),
                           (o_history_valid, o_history_valid_t), test_data,
                           (s_history_test, s_history_test_t),
                           (o_history_test, o_history_test_t))
        model.latest_time = valid_data[0][3]

        for idx in range(len(valid_data)):
            batch_data = valid_data[idx]
            subject_history = (s_history_valid[idx], s_history_valid_t[idx])
            object_history = (o_history_valid[idx], o_history_valid_t[idx])
            if use_cuda:
                batch_data = batch_data.cuda()

            with torch.no_grad():
                ranks, loss = model.evaluate_filter(batch_data,
                                                    subject_history,
                                                    object_history,
                                                    global_model, total_data)
                total_ranks = np.concatenate((total_ranks, ranks))
                total_loss += loss.item()

        mrr = np.mean(1.0 / total_ranks)
        mr = np.mean(total_ranks)
        hits = []
        for hit in [1, 3, 10]:
            avg_count = np.mean((total_ranks <= hit))
            hits.append(avg_count)
            print("valid Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
        print("valid MRR (filtered): {:.6f}".format(mrr))
        print("valid MR (filtered): {:.6f}".format(mr))
        print("valid Loss: {:.6f}".format(total_loss / (len(valid_data))))

        # `not (a > b)` rather than `a <= b` so a NaN MRR never saves.
        if not mrr > best_mrr:
            continue
        best_mrr = mrr

        # Fields shared by both checkpoint files, in their stored order.
        history_state = {
            'epoch': epoch,
            's_hist': model.s_hist_test,
            's_cache': model.s_his_cache,
            'o_hist': model.o_hist_test,
            'o_cache': model.o_his_cache,
            's_hist_t': model.s_hist_test_t,
            's_cache_t': model.s_his_cache_t,
            'o_hist_t': model.o_hist_test_t,
            'o_cache_t': model.o_his_cache_t,
            'latest_time': model.latest_time,
        }
        torch.save(
            {
                'state_dict': model.state_dict(),
                **history_state,
                'global_emb': model.global_emb,
            }, model_state_file)
        torch.save(
            {
                'state_dict': global_model.state_dict(),
                **history_state,
                'global_emb': global_model.global_emb,
            }, model_state_global_file2)
        with open(model_graph_file, 'wb') as fp:
            pickle.dump(model.graph_dict, fp)

    print("training done")
recall_list = [] f1_list = [] f2_list = [] hloss_list = [] iterations = 0 while iterations < args.runs: iterations += 1 print( '****************** iterations ', iterations, ) if iterations == 1: print("loading data...") num_nodes, num_rels = utils.get_total_number(args.dp + args.dataset, 'stat.txt') with open('{}{}/100.w_emb'.format(args.dp, args.dataset), 'rb') as f: word_embeds = pickle.load(f, encoding="latin1") word_embeds = torch.FloatTensor(word_embeds) vocab_size = word_embeds.size(0) train_dataset_loader = DistData(args.dp, args.dataset, num_nodes, num_rels, set_name='train') valid_dataset_loader = DistData(args.dp, args.dataset, num_nodes, num_rels,
def test(args):
    """Evaluate a saved RENet checkpoint on the test split (filtered metrics).

    Restores the model weights and history caches from
    ``models/<dataset><suffix>``, ranks every test quadruple with
    ``model.evaluate_filter`` against all known quadruples, and prints
    Hits@{1,3,10}, MRR and MR.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``model``,
            ``n_hidden`` and ``seq_len``.

    Raises:
        ValueError: If ``args.model`` is not 0, 1 or 2 (no checkpoint file
            name is defined for other values).
    """
    state_file_suffixes = {0: 'attn.pth', 1: 'mean.pth', 2: 'gcn.pth'}
    if args.model not in state_file_suffixes:
        raise ValueError(
            'unsupported args.model {!r}; expected one of {}'.format(
                args.model, sorted(state_file_suffixes)))

    data_dir = './data/' + args.dataset

    # load data
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    test_data, _ = utils.load_quadruples(data_dir, 'test.txt')
    total_data, _ = utils.load_quadruples(data_dir, 'train.txt', 'valid.txt',
                                          'test.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    os.makedirs('models', exist_ok=True)
    # The file name is "<dataset><suffix>" directly under models/ (no
    # sub-directory).
    model_state_file = 'models/' + args.dataset + state_file_suffixes[args.model]

    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  model=args.model,
                  seq_len=args.seq_len)
    if use_cuda:
        model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only use trusted,
    # locally generated history files.
    with open(data_dir + '/test_history_sub.txt', 'rb') as f:
        s_history_test = pickle.load(f)
    with open(data_dir + '/test_history_ob.txt', 'rb') as f:
        o_history_test = pickle.load(f)

    print("\nstart testing:")

    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.s_hist_test = checkpoint['s_hist']
    model.s_his_cache = checkpoint['s_cache']
    model.o_hist_test = checkpoint['o_hist']
    model.o_his_cache = checkpoint['o_cache']
    model.latest_time = checkpoint['latest_time']

    print("Using best epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)
    if use_cuda:
        total_data = total_data.cuda()

    model.eval()

    # Seed with an empty float array so the result keeps the float dtype and
    # an empty test set still concatenates.
    rank_chunks = [np.array([])]
    with torch.no_grad():
        for i in range(len(test_data)):
            batch_data = test_data[i]
            if use_cuda:
                batch_data = batch_data.cuda()
            # Filtered metric; model.evaluate(batch_data, s_hist, o_hist)
            # would give the raw one.
            ranks_row, _ = model.evaluate_filter(batch_data,
                                                 s_history_test[i],
                                                 o_history_test[i],
                                                 total_data)
            rank_chunks.append(ranks_row)

    total_ranks = np.concatenate(rank_chunks)

    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)
    hits = []
    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        hits.append(avg_count)
        print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
    print("MRR (filtered): {:.6f}".format(mrr))
    print("MR (filtered): {:.6f}".format(mr))
def train(args):
    """Pre-train RENet_global and keep the checkpoint with the lowest epoch loss.

    Each epoch shuffles the training timestamps together with the true
    subject/object distributions, optimises the model with Adam and gradient
    clipping, recomputes ``model.global_emb`` over the unshuffled timestamps,
    and saves a checkpoint whenever the summed epoch loss reaches a new
    minimum.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``model``,
            ``maxpool``, ``n_hidden``, ``dropout``, ``seq_len``, ``num_k``,
            ``lr``, ``batch_size``, ``grad_norm`` and ``max_epochs``.

    Raises:
        ValueError: If ``args.model`` is not 0, 1, 2 or 3 (no checkpoint file
            name is defined for other values).
    """
    model_state_files = {
        0: 'models/' + args.dataset + 'attn.pth',
        1: 'models/' + args.dataset + 'mean.pth',
        2: 'models/' + args.dataset + 'gcn.pth',
        3: 'models/' + args.dataset + '/max' + str(args.maxpool) +
           'rgcn_global.pth',
    }
    if args.model not in model_state_files:
        # Fail before training instead of at the first checkpoint save.
        raise ValueError(
            'unsupported args.model {!r}; expected one of {}'.format(
                args.model, sorted(model_state_files)))
    model_state_file = model_state_files[args.model]

    # load data
    num_nodes, num_rels = utils.get_total_number('./data/' + args.dataset,
                                                 'stat.txt')
    train_data, train_times_origin = utils.load_quadruples(
        './data/' + args.dataset, 'train.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    os.makedirs('models', exist_ok=True)
    os.makedirs('models/' + args.dataset, exist_ok=True)

    print("start training...")
    model = RENet_global(num_nodes,
                         args.n_hidden,
                         num_rels,
                         dropout=args.dropout,
                         model=args.model,
                         seq_len=args.seq_len,
                         num_k=args.num_k,
                         maxpool=args.maxpool)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)
    if use_cuda:
        model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only use trusted,
    # locally generated graph files.
    with open('./data/' + args.dataset + '/train_graphs.txt', 'rb') as f:
        graph_dict = pickle.load(f)

    # Entry i of each distribution belongs to train_times_origin[i]; these
    # originals must never be overwritten by a shuffled copy.
    true_prob_s, true_prob_o = utils.get_true_distribution(
        train_data, num_nodes)

    # Start at infinity so the first epoch always produces a checkpoint.
    loss_small = float('inf')
    for epoch in range(1, args.max_epochs + 1):
        model.train()
        loss_epoch = 0
        t0 = time.time()

        # Shuffle into fresh names: reusing true_prob_s/true_prob_o here
        # would pair already-shuffled targets with the ordered timestamps on
        # the next epoch.
        train_times, shuffled_prob_s, shuffled_prob_o = shuffle(
            train_times_origin, true_prob_s, true_prob_o)

        for batch_data, true_s, true_o in utils.make_batch(
                train_times, shuffled_prob_s, shuffled_prob_o,
                args.batch_size):
            batch_data = torch.from_numpy(batch_data)
            true_s = torch.from_numpy(true_s)
            true_o = torch.from_numpy(true_o)
            if use_cuda:
                batch_data = batch_data.cuda()
                true_s = true_s.cuda()
                true_o = true_o.cuda()

            loss = model(batch_data, true_s, true_o, graph_dict)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()
        # Recomputed over the unshuffled timestamps after every epoch.
        model.global_emb = model.get_global_emb(train_times_origin,
                                                graph_dict)
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_times) / args.batch_size),
            t3 - t0))

        if loss_epoch < loss_small:
            loss_small = loss_epoch
            if args.model == 3:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'global_emb': model.global_emb
                    }, model_state_file)
            else:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'epoch': epoch,
                        's_hist': model.s_hist_test,
                        's_cache': model.s_his_cache,
                        'o_hist': model.o_hist_test,
                        'o_cache': model.o_his_cache,
                        'latest_time': model.latest_time
                    }, model_state_file)

    print("training done")
def train(args):
    """Train DArtNet, optionally resuming, and checkpoint after every epoch.

    When ``args.retrain`` is non-zero the run resumes from
    ``<model_dir>/checkpoint.pth`` (weights and optimiser state) or, if that
    file is missing, from the highest-numbered ``epoch-<n>.pth`` (weights
    only); if neither can be used, training starts from scratch. After each
    epoch both ``epoch-<n>.pth`` and ``checkpoint.pth`` are written.

    Args:
        args: Parsed options. Reads ``dataset_path``, ``dataset``, ``gpu``,
            ``dropout``, ``n_hidden``, ``gamma``, ``num_k``, ``model``,
            ``seq_len``, ``lr``, ``retrain``, ``batch_size``, ``grad_norm``
            and ``max_epochs``.
    """
    # Order matters: it is the positional order used by shuffle and
    # utils.make_batch3 after the training rows.
    history_keys = ('entity_s', 'rel_s', 'entity_o', 'rel_o',
                    'att_s', 'self_att_s', 'att_o', 'self_att_o')

    # load data
    num_nodes, num_rels, num_att = utils.get_total_number(
        args.dataset_path, 'stat.txt')
    train_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)
    # Checkpoints are mapped to the device this run actually uses, so
    # resuming also works on CPU-only runs.
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)
    os.makedirs(model_dir, exist_ok=True)  # also creates the parent dirs

    print("start training...")
    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    num_att,
                    dropout=args.dropout,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)
    if use_cuda:
        model.cuda()

    # NOTE: pickle.load can execute arbitrary code; only point dataset_path
    # at trusted, locally generated history files.
    histories = []
    for key in history_keys:
        path = args.dataset_path + '/train_{}_history_data.txt'.format(key)
        with open(path, 'rb') as f:
            histories.append(pickle.load(f))

    epoch = 0
    if args.retrain != 0:
        try:
            checkpoint = torch.load(model_dir + '/checkpoint.pth',
                                    map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            model.latest_time = checkpoint['latest_time']
            model.to(device)
        except FileNotFoundError:
            # No rolling checkpoint: fall back to the newest per-epoch file.
            prefix, suffix = 'epoch-', '.pth'
            saved_epochs = [
                int(name[len(prefix):-len(suffix)])
                for name in os.listdir(model_dir)
                if name.startswith(prefix) and name.endswith(suffix)
                and name[len(prefix):-len(suffix)].isdigit()
            ]
            restored = False
            if saved_epochs:
                try:
                    checkpoint = torch.load(
                        model_dir + '/epoch-{}.pth'.format(max(saved_epochs)),
                        map_location=device)
                    model.load_state_dict(checkpoint['state_dict'])
                    epoch = checkpoint['epoch']
                    model.latest_time = checkpoint['latest_time']
                    model.to(device)
                    restored = True
                except Exception as exc:
                    # Resuming is best effort, but report why it failed
                    # instead of hiding the error.
                    print('could not load epoch checkpoint: {!r}'.format(exc))
            if not restored:
                epoch = 0
                print('no model found')
                print('training from scratch')

    # `<` rather than `==` so a resumed epoch beyond max_epochs terminates.
    while epoch < args.max_epochs:
        model.train()
        epoch += 1
        loss_epoch = 0
        loss_att_sub_epoch = 0
        t0 = time.time()

        # Rows and all eight histories are shuffled together so they stay
        # aligned.
        train_data, *histories = shuffle(train_data, *histories)

        iteration = 0
        for (batch_data, s_hist, rel_s_hist, o_hist, rel_o_hist, att_s_hist,
             self_att_s_hist, att_o_hist,
             self_att_o_hist) in utils.make_batch3(train_data, *histories,
                                                   args.batch_size):
            iteration += 1
            print(f'iteration {iteration}', end='\r')

            batch_data = torch.from_numpy(batch_data)
            if use_cuda:
                batch_data = batch_data.cuda()

            # get_loss takes the histories in a different order than
            # make_batch3 yields them.
            loss, loss_att_sub = model.get_loss(batch_data, s_hist,
                                                rel_s_hist, att_s_hist,
                                                self_att_s_hist, o_hist,
                                                rel_o_hist, att_o_hist,
                                                self_att_o_hist)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()
            loss_att_sub_epoch += loss_att_sub.item()

        t3 = time.time()
        print(
            "Epoch {:04d} | Loss {:.4f} | Loss_att_sub {:.4f} | time {:.4f} ".
            format(epoch, loss_epoch / (len(train_data) / args.batch_size),
                   loss_att_sub_epoch / (len(train_data) / args.batch_size),
                   t3 - t0))

        # Per-epoch snapshot (weights only), read by the test script.
        torch.save(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/epoch-{}.pth'.format(epoch))

        # Rolling checkpoint with optimiser state, used for resuming.
        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/checkpoint.pth')

    print("training done")