def test(args):
    """Evaluate a saved DArtNet checkpoint on the test split.

    Loads ``models/<dataset>/<dropout>-<n_hidden>-<gamma>-<num_k>/epoch-<args.epoch>.pth``
    and walks the test rows in runs that share the same last column (presumably
    the timestamp).

    For each run it first feeds the rows through ``model.predict`` in mini-batches
    of ``args.batch_size``, accumulating the attribute loss. It then ranks every
    row individually with ``model.evaluate_filter``.

    Prints filtered Hits@{1,3,10}, MRR, MR and the mean attribute loss. The
    summary is stored in the module-level ``result_dict`` under ``args.epoch``.
    """
    # load data
    num_nodes, num_rels, num_att = utils.get_total_number(
        args.dataset_path, 'stat.txt')
    test_data, _ = utils.load_hexaruples(args.dataset_path, 'test.txt')
    total_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt',
                                          'valid.txt', 'test.txt')

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)
    model_state_file = model_dir + '/epoch-{}.pth'.format(args.epoch)

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    # Construct the model with the same arguments train() uses, so the saved
    # state_dict matches the module layout.
    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    num_att,
                    dropout=args.dropout,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)
    if use_cuda:
        model.cuda()

    # History feature files, listed in the positional order that
    # model.predict / model.evaluate_filter take them after the batch tensor.
    history_kinds = ('entity_s', 'rel_s', 'att_s', 'self_att_s',
                     'entity_o', 'rel_o', 'att_o', 'self_att_o')
    history = {}
    for kind in history_kinds:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataset files from a trusted source.
        with open(f'{args.dataset_path}/test_{kind}_history_data.txt',
                  'rb') as f:
            history[kind] = pickle.load(f)

    def _histories_at(index):
        """Return copies of every history feature at *index* (int or slice)."""
        return [history[kind][index].copy() for kind in history_kinds]

    print(f'\nstart testing model file : {model_state_file}')
    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.init_history()
    model.latest_time = checkpoint['latest_time']
    print("Using epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)
    num_test = len(test_data)

    model.eval()
    total_att_sub_loss = 0
    rank_chunks = []
    with torch.no_grad():
        j = 0
        while j < num_test:
            # [j, k) is the run of consecutive rows whose last column equals
            # row j's; assumes test_data is sorted by that column.
            k = j
            while k < num_test and test_data[k][-1] == test_data[j][-1]:
                k += 1

            # Pass 1: push the whole run through predict() in mini-batches.
            for start in range(j, k, args.batch_size):
                end = min(k, start + args.batch_size)
                batch_data = test_data[start:end].clone()
                if use_cuda:
                    batch_data = batch_data.cuda()
                loss_sub = model.predict(batch_data,
                                         *_histories_at(slice(start, end)))
                # Weight by the real batch size (end - start) so dividing by
                # num_test below yields a per-row mean.
                total_att_sub_loss += loss_sub.item() * (end - start)

            # Pass 2: rank each row of the run individually (filtered setting).
            for i in range(j, k):
                batch_data = test_data[i].clone()
                if use_cuda:
                    batch_data = batch_data.cuda()
                rank_chunks.append(
                    model.evaluate_filter(batch_data, *_histories_at(i),
                                          total_data))
            j = k

    # Concatenate once (not per row). The empty float64 seed preserves the
    # result dtype and makes the no-data case produce an empty array.
    total_ranks = np.concatenate([np.array([])] + rank_chunks)

    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)
    hits = []
    for hit in [1, 3, 10]:
        avg_count = np.mean(total_ranks <= hit)
        hits.append(avg_count)
        print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
    print("MRR (filtered): {:.6f}".format(mrr))
    print("MR (filtered): {:.6f}".format(mr))

    att_sub_loss = total_att_sub_loss / num_test
    print("test att sub Loss: {:.6f}".format(att_sub_loss))

    result_epoch = result(epoch=args.epoch,
                          MRR=100 * mrr,
                          sub_att_loss=att_sub_loss,
                          MR=mr,
                          Hits1=100 * hits[0],
                          Hits3=100 * hits[1],
                          Hits10=100 * hits[2])
    result_dict[args.epoch] = result_epoch
def _latest_epoch_snapshot(model_dir):
    """Return the path of the highest-numbered ``epoch-<N>.pth`` in *model_dir*.

    Files whose middle part is not all digits are ignored (so ``checkpoint.pth``
    or other ``.pth`` files cannot break the parse). Returns ``None`` when no
    snapshot exists.
    """
    prefix, suffix = 'epoch-', '.pth'
    epochs = []
    for name in os.listdir(model_dir):
        if name.startswith(prefix) and name.endswith(suffix):
            number = name[len(prefix):-len(suffix)]
            if number.isdigit():
                epochs.append(int(number))
    if not epochs:
        return None
    return model_dir + '/epoch-{}.pth'.format(max(epochs))


def train(args):
    """Train DArtNet on the training split, saving after every epoch.

    When ``args.retrain`` is non-zero, training resumes from
    ``<model_dir>/checkpoint.pth`` (weights + optimizer + epoch). If that is
    missing, it falls back to the newest ``epoch-<N>.pth`` snapshot (weights
    only). Otherwise it trains from scratch.

    Each epoch writes ``epoch-<epoch>.pth`` and overwrites ``checkpoint.pth``.
    Training stops once ``epoch`` reaches ``args.max_epochs``.
    """
    # load data
    num_nodes, num_rels, num_att = utils.get_total_number(
        args.dataset_path, 'stat.txt')
    train_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    # Checkpoints are mapped onto this device, so resuming also works without CUDA.
    device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)
    # makedirs is recursive, so this also creates 'models/' and 'models/<dataset>/'.
    os.makedirs(model_dir, exist_ok=True)

    print("start training...")
    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    num_att,
                    dropout=args.dropout,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    if use_cuda:
        model.cuda()

    def _load_history(kind):
        """Unpickle ``<dataset_path>/train_<kind>_history_data.txt``."""
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataset files from a trusted source.
        with open(f'{args.dataset_path}/train_{kind}_history_data.txt',
                  'rb') as f:
            return pickle.load(f)

    entity_s_history_train = _load_history('entity_s')
    rel_s_history_train = _load_history('rel_s')
    att_s_history_train = _load_history('att_s')
    self_att_s_history_train = _load_history('self_att_s')
    entity_o_history_train = _load_history('entity_o')
    rel_o_history_train = _load_history('rel_o')
    att_o_history_train = _load_history('att_o')
    self_att_o_history_train = _load_history('self_att_o')

    epoch = 0
    if args.retrain != 0:
        checkpoint_file = model_dir + '/checkpoint.pth'
        if not os.path.isfile(checkpoint_file):
            # No full checkpoint: fall back to the newest weights-only snapshot.
            checkpoint_file = _latest_epoch_snapshot(model_dir)
        if checkpoint_file is None:
            print('no model found')
            print('training from scratch')
        else:
            checkpoint = torch.load(checkpoint_file, map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            # epoch-<N>.pth snapshots carry no optimizer state; Adam then
            # restarts from its fresh state.
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            model.latest_time = checkpoint['latest_time']
            model.to(device)

    # "<" (not "!="): a resumed epoch already past max_epochs must stop
    # immediately instead of training forever.
    while epoch < args.max_epochs:
        model.train()
        epoch += 1
        loss_epoch = 0
        loss_att_sub_epoch = 0
        t0 = time.time()

        (train_data, entity_s_history_train, rel_s_history_train,
         entity_o_history_train, rel_o_history_train, att_s_history_train,
         self_att_s_history_train, att_o_history_train,
         self_att_o_history_train) = shuffle(
             train_data, entity_s_history_train, rel_s_history_train,
             entity_o_history_train, rel_o_history_train, att_s_history_train,
             self_att_s_history_train, att_o_history_train,
             self_att_o_history_train)

        iteration = 0
        for (batch_data, s_hist, rel_s_hist, o_hist, rel_o_hist, att_s_hist,
             self_att_s_hist, att_o_hist,
             self_att_o_hist) in utils.make_batch3(
                 train_data, entity_s_history_train, rel_s_history_train,
                 entity_o_history_train, rel_o_history_train,
                 att_s_history_train, self_att_s_history_train,
                 att_o_history_train, self_att_o_history_train,
                 args.batch_size):
            iteration += 1
            print(f'iteration {iteration}', end='\r')
            batch_data = torch.from_numpy(batch_data)
            if use_cuda:
                batch_data = batch_data.cuda()

            loss, loss_att_sub = model.get_loss(batch_data, s_hist,
                                                rel_s_hist, att_s_hist,
                                                self_att_s_hist, o_hist,
                                                rel_o_hist, att_o_hist,
                                                self_att_o_hist)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()
            loss_att_sub_epoch += loss_att_sub.item()

        t3 = time.time()
        # Mean per batch over the batches actually run (the last may be
        # partial); max() guards an empty training split.
        num_batches = max(iteration, 1)
        print(
            "Epoch {:04d} | Loss {:.4f} | Loss_att_sub {:.4f} | time {:.4f} ".
            format(epoch, loss_epoch / num_batches,
                   loss_att_sub_epoch / num_batches, t3 - t0))

        # Weights-only snapshot for this epoch (read by test()).
        torch.save(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/epoch-{}.pth'.format(epoch))

        # Rolling full checkpoint, including optimizer state, for resuming.
        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/checkpoint.pth')

    print("training done")