Ejemplo n.º 1
0
def test(args):
    """Evaluate a saved DArtNet checkpoint on the test split.

    Loads the test data and the combined train/valid/test data from
    ``args.dataset_path``, restores the checkpoint written for ``args.epoch``
    under ``models/<dataset>/<dropout>-<n_hidden>-<gamma>-<num_k>/`` and walks
    the test set in groups of consecutive rows that share the same value in
    their last column (presumably the timestamp -- confirm against
    ``utils.load_hexaruples``). For every group the attribute loss is
    computed in mini-batches with ``model.predict`` and afterwards every row
    is ranked with ``model.evaluate_filter``.

    Prints filtered Hits@{1,3,10}, MRR, MR and the mean attribute loss, and
    stores a ``result`` record in the module-level ``result_dict`` under
    ``args.epoch``.

    Args:
        args: Parsed options. Reads ``dataset_path``, ``dataset``,
            ``dropout``, ``n_hidden``, ``gamma``, ``num_k``, ``epoch``,
            ``gpu``, ``model``, ``seq_len`` and ``batch_size``.
    """

    def load_history(file_name):
        # History files are pickles despite the .txt extension.
        with open(args.dataset_path + file_name, 'rb') as f:
            return pickle.load(f)

    # load data
    num_nodes, num_rels = utils.get_total_number(args.dataset_path, 'stat.txt')
    test_data, _ = utils.load_hexaruples(args.dataset_path, 'test.txt')
    total_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt',
                                          'valid.txt', 'test.txt')

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)
    model_state_file = model_dir + '/epoch-{}.pth'.format(args.epoch)

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)

    if use_cuda:
        model.cuda()

    # Subject-side histories, then object-side histories.
    entity_s_history_data_test = load_history(
        '/test_entity_s_history_data.txt')
    rel_s_history_data_test = load_history('/test_rel_s_history_data.txt')
    att_s_history_data_test = load_history('/test_att_s_history_data.txt')
    self_att_s_history_data_test = load_history(
        '/test_self_att_s_history_data.txt')

    entity_o_history_data_test = load_history(
        '/test_entity_o_history_data.txt')
    rel_o_history_data_test = load_history('/test_rel_o_history_data.txt')
    att_o_history_data_test = load_history('/test_att_o_history_data.txt')
    self_att_o_history_data_test = load_history(
        '/test_self_att_o_history_data.txt')

    print(f'\nstart testing model file : {model_state_file}')

    # Always deserialize onto the CPU; the weights are copied into the
    # (possibly CUDA-resident) model by load_state_dict.
    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)

    model.load_state_dict(checkpoint['state_dict'])

    model.init_history()

    model.latest_time = checkpoint['latest_time']

    print("Using epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)

    model.eval()
    total_att_sub_loss = 0
    rank_chunks = []
    num_samples = len(test_data)

    with torch.no_grad():
        group_start = 0
        while group_start < num_samples:
            # Find the end of the run of rows sharing the same last column.
            group_key = test_data[group_start][-1]
            group_end = group_start
            while (group_end < num_samples
                   and test_data[group_end][-1] == group_key):
                group_end += 1

            # Attribute loss, in mini-batches over the current group.
            for start in range(group_start, group_end, args.batch_size):
                end = min(group_end, start + args.batch_size)

                batch_data = test_data[start:end].clone()
                s_hist = entity_s_history_data_test[start:end].copy()
                o_hist = entity_o_history_data_test[start:end].copy()
                rel_s_hist = rel_s_history_data_test[start:end].copy()
                rel_o_hist = rel_o_history_data_test[start:end].copy()
                att_s_hist = att_s_history_data_test[start:end].copy()
                att_o_hist = att_o_history_data_test[start:end].copy()
                self_att_s_hist = (
                    self_att_s_history_data_test[start:end].copy())
                self_att_o_hist = (
                    self_att_o_history_data_test[start:end].copy())

                if use_cuda:
                    batch_data = batch_data.cuda()

                loss_sub = model.predict(batch_data, s_hist, rel_s_hist,
                                         att_s_hist, self_att_s_hist, o_hist,
                                         rel_o_hist, att_o_hist,
                                         self_att_o_hist)

                # Weight by the number of rows in this batch (end - start),
                # so the weights add up to len(test_data), which the total
                # is divided by below. The previous `+ 1` over-counted one
                # row per batch. Assumes predict() returns a per-batch mean
                # loss -- confirm against the model.
                total_att_sub_loss += loss_sub.item() * (end - start)

            # Filtered ranking, one row at a time, after the whole group
            # has been passed through predict().
            for i in range(group_start, group_end):
                batch_data = test_data[i].clone()
                s_hist = entity_s_history_data_test[i].copy()
                o_hist = entity_o_history_data_test[i].copy()
                rel_s_hist = rel_s_history_data_test[i].copy()
                rel_o_hist = rel_o_history_data_test[i].copy()
                att_s_hist = att_s_history_data_test[i].copy()
                att_o_hist = att_o_history_data_test[i].copy()
                self_att_s_hist = self_att_s_history_data_test[i].copy()
                self_att_o_hist = self_att_o_history_data_test[i].copy()

                if use_cuda:
                    batch_data = batch_data.cuda()

                ranks_pred = model.evaluate_filter(batch_data, s_hist,
                                                   rel_s_hist, att_s_hist,
                                                   self_att_s_hist, o_hist,
                                                   rel_o_hist, att_o_hist,
                                                   self_att_o_hist, total_data)
                rank_chunks.append(ranks_pred)

            group_start = group_end

    # Single concatenation instead of growing an array inside the loop; the
    # leading empty float array keeps the result dtype and the empty case
    # identical to the incremental version.
    total_ranks = np.concatenate([np.array([])] + rank_chunks)
    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)
    hits = []

    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        hits.append(avg_count)

        print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
    print("MRR (filtered): {:.6f}".format(mrr))
    print("MR (filtered): {:.6f}".format(mr))
    print("test att sub Loss: {:.6f}".format(total_att_sub_loss /
                                             (len(test_data))))

    # `result` and `result_dict` are module-level names defined outside
    # this function.
    result_epoch = result(epoch=args.epoch,
                          MRR=100 * mrr,
                          sub_att_loss=total_att_sub_loss / len(test_data),
                          MR=mr,
                          Hits1=100 * hits[0],
                          Hits3=100 * hits[1],
                          Hits10=100 * hits[2])

    result_dict[args.epoch] = result_epoch
Ejemplo n.º 2
0
def train(args):
    """Train a RENet model and checkpoint the epoch with the best valid MRR.

    Loads the train and valid quadruples and their pickled history files
    from ``./data/<dataset>``, trains with Adam for ``args.max_epochs``
    epochs, and after every epoch evaluates on the validation split with
    ``model.evaluate`` (raw metrics). Whenever the validation MRR improves,
    the model weights and its history caches are saved to
    ``models/<dataset><suffix>``, where the suffix depends on ``args.model``.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``model``,
            ``n_hidden``, ``dropout``, ``seq_len``, ``lr``, ``max_epochs``,
            ``batch_size`` and ``grad_norm``.

    Raises:
        ValueError: If ``args.model`` is not 0, 1 or 2. Previously such a
            value was only noticed as an unbound ``model_state_file`` when
            the first checkpoint was written, after a full epoch of
            training.
    """
    # Checkpoint file suffix per aggregator id. There is deliberately no
    # path separator between the dataset name and the suffix: the matching
    # test code reads the file back from the same location.
    checkpoint_suffixes = {0: 'attn.pth', 1: 'mean.pth', 2: 'gcn.pth'}
    if args.model not in checkpoint_suffixes:
        raise ValueError(
            'unsupported args.model {!r}; expected one of {}'.format(
                args.model, sorted(checkpoint_suffixes)))

    # load data
    data_dir = './data/' + args.dataset
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    train_data, _ = utils.load_quadruples(data_dir, 'train.txt')
    valid_data, _ = utils.load_quadruples(data_dir, 'valid.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    os.makedirs('models', exist_ok=True)
    model_state_file = ('models/' + args.dataset +
                        checkpoint_suffixes[args.model])

    print("start training...")
    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  dropout=args.dropout,
                  model=args.model,
                  seq_len=args.seq_len)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    if use_cuda:
        model.cuda()

    # History files are pickles despite the .txt extension.
    with open(data_dir + '/train_history_sub.txt', 'rb') as f:
        s_history = pickle.load(f)
    with open(data_dir + '/train_history_ob.txt', 'rb') as f:
        o_history = pickle.load(f)

    with open(data_dir + '/dev_history_sub.txt', 'rb') as f:
        s_history_valid = pickle.load(f)
    with open(data_dir + '/dev_history_ob.txt', 'rb') as f:
        o_history_valid = pickle.load(f)
    valid_data = torch.from_numpy(valid_data)

    epoch = 0
    best_mrr = 0
    while epoch < args.max_epochs:
        epoch += 1
        model.train()
        loss_epoch = 0
        t0 = time.time()

        # Shuffle the samples and their histories together.
        train_data, s_history, o_history = shuffle(train_data, s_history,
                                                   o_history)
        for batch_data, s_hist, o_hist in utils.make_batch(
                train_data, s_history, o_history, args.batch_size):
            batch_data = torch.from_numpy(batch_data)
            if use_cuda:
                batch_data = batch_data.cuda()

            loss = model.get_loss(batch_data, s_hist, o_hist)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()

        # The denominator is the (possibly fractional) number of batches,
        # so this is an approximate mean loss per batch.
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_data) / args.batch_size), t3 - t0))

        # Validation runs after every epoch.
        model.eval()
        total_loss = 0
        rank_chunks = []
        model.init_history()
        model.latest_time = valid_data[0][3]

        with torch.no_grad():
            for idx, batch_data in enumerate(valid_data):
                if use_cuda:
                    batch_data = batch_data.cuda()

                ranks, loss = model.evaluate(batch_data,
                                             s_history_valid[idx],
                                             o_history_valid[idx])
                rank_chunks.append(ranks)
                total_loss += loss.item()

        # Single concatenation instead of growing an array per sample; the
        # leading empty float array keeps the dtype and empty case the same.
        total_ranks = np.concatenate([np.array([])] + rank_chunks)

        mrr = np.mean(1.0 / total_ranks)
        mr = np.mean(total_ranks)
        for hit in [1, 3, 10]:
            avg_count = np.mean((total_ranks <= hit))
            print("valid Hits (raw) @ {}: {:.6f}".format(hit, avg_count))
        print("valid MRR (raw): {:.6f}".format(mrr))
        print("valid MR (raw): {:.6f}".format(mr))
        print("valid Loss: {:.6f}".format(total_loss / (len(valid_data))))

        if mrr > best_mrr:
            best_mrr = mrr
            # Save the history caches built during validation together
            # with the weights; the test code restores them by these keys.
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    's_hist': model.s_hist_test,
                    's_cache': model.s_his_cache,
                    'o_hist': model.o_hist_test,
                    'o_cache': model.o_his_cache,
                    'latest_time': model.latest_time
                }, model_state_file)

    print("training done")
Ejemplo n.º 3
0
def test(args):
    """Evaluate the best saved RENet / RENet_global pair on the test split.

    Restores the RENet checkpoint ``models/<dataset>/rgcn.pth`` (weights,
    history caches, latest time and global embedding), the pickled graph
    dictionary ``models/<dataset>/rgcn_graph.pth`` and the RENet_global
    checkpoint ``models/<dataset>/max<maxpool>rgcn_global2.pth``, then ranks
    every test quadruple one at a time and prints Hits@{1,3,10}, MRR and MR.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``n_hidden``,
            ``model``, ``seq_len``, ``num_k``, ``maxpool`` and ``raw``. When
            ``args.raw`` is truthy ``model.evaluate`` is used, otherwise
            ``model.evaluate_filter`` is used together with the combined
            train/valid/test data.
    """
    # load data
    data_dir = './data/' + args.dataset
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    test_data, _ = utils.load_quadruples(data_dir, 'test.txt')
    # For icews_know the combined data is built from train + test only.
    if args.dataset == 'icews_know':
        total_files = ('train.txt', 'test.txt')
    else:
        total_files = ('train.txt', 'valid.txt', 'test.txt')
    total_data, _ = utils.load_quadruples(data_dir, *total_files)

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    model_state_file = 'models/' + args.dataset + '/rgcn.pth'
    model_graph_file = 'models/' + args.dataset + '/rgcn_graph.pth'
    model_state_global_file2 = 'models/' + args.dataset + '/max' + str(
        args.maxpool) + 'rgcn_global2.pth'

    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  model=args.model,
                  seq_len=args.seq_len,
                  num_k=args.num_k)
    global_model = RENet_global(num_nodes,
                                args.n_hidden,
                                num_rels,
                                model=args.model,
                                seq_len=args.seq_len,
                                num_k=args.num_k,
                                maxpool=args.maxpool)

    if use_cuda:
        model.cuda()
        global_model.cuda()

    # History files are pickles despite the .txt extension; index 0 holds
    # the histories and index 1 the companion "_t" histories.
    with open('data/' + args.dataset + '/test_history_sub.txt', 'rb') as f:
        s_history_test_data = pickle.load(f)
    with open('data/' + args.dataset + '/test_history_ob.txt', 'rb') as f:
        o_history_test_data = pickle.load(f)

    s_history_test = s_history_test_data[0]
    s_history_test_t = s_history_test_data[1]
    o_history_test = o_history_test_data[0]
    o_history_test_t = o_history_test_data[1]

    print("\nstart testing:")

    # Always deserialize onto the CPU.
    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.s_hist_test = checkpoint['s_hist']
    model.s_his_cache = checkpoint['s_cache']
    model.o_hist_test = checkpoint['o_hist']
    model.o_his_cache = checkpoint['o_cache']
    model.latest_time = checkpoint['latest_time']
    if args.dataset == "icews_know":
        # NOTE(review): hard-coded override of the restored latest time for
        # this dataset; the origin of 4344 is not visible here -- confirm.
        model.latest_time = torch.LongTensor([4344])[0]
    model.global_emb = checkpoint['global_emb']
    model.s_hist_test_t = checkpoint['s_hist_t']
    model.s_his_cache_t = checkpoint['s_cache_t']
    model.o_hist_test_t = checkpoint['o_hist_t']
    model.o_his_cache_t = checkpoint['o_cache_t']
    with open(model_graph_file, 'rb') as f:
        model.graph_dict = pickle.load(f)

    checkpoint_global = torch.load(model_state_global_file2,
                                   map_location=lambda storage, loc: storage)
    global_model.load_state_dict(checkpoint_global['state_dict'])

    print("Using best epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)

    model.eval()
    global_model.eval()

    # Keep only the most recent seq_len entries of every restored history.
    # The "_t" list is trimmed by the same amount as its companion list.
    for ee in range(num_nodes):
        excess = len(model.s_hist_test[ee]) - args.seq_len
        if excess > 0:
            del model.s_hist_test[ee][:excess]
            del model.s_hist_test_t[ee][:excess]
        excess = len(model.o_hist_test[ee]) - args.seq_len
        if excess > 0:
            del model.o_hist_test[ee][:excess]
            del model.o_hist_test_t[ee][:excess]

    if use_cuda:
        total_data = total_data.cuda()

    rank_chunks = []
    with torch.no_grad():
        for i in range(len(test_data)):
            batch_data = test_data[i]
            # Both evaluation calls take (history, history_t) pairs, so the
            # "_t" histories are needed for every args.model value. They
            # used to be assigned only when args.model == 3, which raised
            # NameError for any other model.
            s_pair = (s_history_test[i], s_history_test_t[i])
            o_pair = (o_history_test[i], o_history_test_t[i])

            if use_cuda:
                batch_data = batch_data.cuda()

            if args.raw:
                ranks_i, _ = model.evaluate(batch_data, s_pair, o_pair,
                                            global_model)
            else:
                ranks_i, _ = model.evaluate_filter(batch_data, s_pair, o_pair,
                                                   global_model, total_data)
            rank_chunks.append(ranks_i)

    # Single concatenation instead of growing an array per sample; the
    # leading empty float array keeps the dtype and empty case the same.
    total_ranks = np.concatenate([np.array([])] + rank_chunks)
    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)

    # Label the output with the metric that was actually computed.
    metric = 'raw' if args.raw else 'filtered'
    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        print("Hits ({}) @ {}: {:.6f}".format(metric, hit, avg_count))
    print("MRR ({}): {:.6f}".format(metric, mrr))
    print("MR ({}): {:.6f}".format(metric, mr))
Ejemplo n.º 4
0
def train(args):
    """Train RENet on top of a pre-trained RENet_global model.

    Reads the quadruple splits, the pickled graph dictionary and the pickled
    train/valid/test histories from ``./data/<dataset>``, loads the
    RENet_global checkpoint ``models/<dataset>/max<maxpool>rgcn_global.pth``
    and optimises the RENet parameters with Adam. From the second half of
    training on, every ``args.valid_every`` epochs the model is scored on the
    validation split with ``model.evaluate_filter``; when the MRR improves,
    the RENet checkpoint, a RENet_global checkpoint and the graph dictionary
    are written under ``models/<dataset>/``.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``maxpool``,
            ``n_hidden``, ``dropout``, ``model``, ``seq_len``, ``num_k``,
            ``lr``, ``max_epochs``, ``batch_size``, ``grad_norm`` and
            ``valid_every``.
    """

    def read_pickle(path):
        # History and graph files are pickles despite the .txt extension.
        with open(path, 'rb') as f:
            return pickle.load(f)

    data_dir = './data/' + args.dataset
    is_icews_know = args.dataset == 'icews_know'

    # For icews_know the test split doubles as the validation split and the
    # combined data is built from train + test only.
    if is_icews_know:
        valid_file = 'test.txt'
        total_files = ('train.txt', 'test.txt')
        valid_sub, valid_ob = '/test_history_sub.txt', '/test_history_ob.txt'
    else:
        valid_file = 'valid.txt'
        total_files = ('train.txt', 'valid.txt', 'test.txt')
        valid_sub, valid_ob = '/dev_history_sub.txt', '/dev_history_ob.txt'

    # load data
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    train_data, train_times = utils.load_quadruples(data_dir, 'train.txt')
    valid_data, valid_times = utils.load_quadruples(data_dir, valid_file)
    test_data, test_times = utils.load_quadruples(data_dir, 'test.txt')
    total_data, total_times = utils.load_quadruples(data_dir, *total_files)

    # check cuda and fix the CPU-side random seeds
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    model_dir = 'models/' + args.dataset
    os.makedirs('models', exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    global_prefix = model_dir + '/max' + str(args.maxpool)
    model_state_file = model_dir + '/rgcn.pth'
    model_graph_file = model_dir + '/rgcn_graph.pth'
    model_state_global_file2 = global_prefix + 'rgcn_global2.pth'
    model_state_global_file = global_prefix + 'rgcn_global.pth'

    print("start training...")
    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  dropout=args.dropout,
                  model=args.model,
                  seq_len=args.seq_len,
                  num_k=args.num_k)

    global_model = RENet_global(num_nodes,
                                args.n_hidden,
                                num_rels,
                                dropout=args.dropout,
                                model=args.model,
                                seq_len=args.seq_len,
                                num_k=args.num_k,
                                maxpool=args.maxpool)

    # Only the RENet parameters are handed to the optimiser.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    # Restore the pre-trained global model (deserialised onto the CPU) and
    # share its global embedding with the local model.
    checkpoint_global = torch.load(model_state_global_file,
                                   map_location=lambda storage, loc: storage)
    global_model.load_state_dict(checkpoint_global['state_dict'])
    model.global_emb = checkpoint_global['global_emb']
    if use_cuda:
        model.cuda()
        global_model.cuda()

    graph_dict = read_pickle('./data/' + args.dataset + '/train_graphs.txt')
    model.graph_dict = graph_dict

    # Each history pickle holds the histories at index 0 and the companion
    # "_t" histories at index 1.
    s_history_test_data = read_pickle(
        'data/' + args.dataset + '/test_history_sub.txt')
    o_history_test_data = read_pickle(
        'data/' + args.dataset + '/test_history_ob.txt')
    s_history_test, s_history_test_t = (s_history_test_data[0],
                                        s_history_test_data[1])
    o_history_test, o_history_test_t = (o_history_test_data[0],
                                        o_history_test_data[1])

    s_history_data = read_pickle(
        './data/' + args.dataset + '/train_history_sub.txt')
    o_history_data = read_pickle(
        './data/' + args.dataset + '/train_history_ob.txt')

    s_history_valid_data = read_pickle('./data/' + args.dataset + valid_sub)
    o_history_valid_data = read_pickle('./data/' + args.dataset + valid_ob)
    valid_data = torch.from_numpy(valid_data)

    s_history, s_history_t = s_history_data[0], s_history_data[1]
    o_history, o_history_t = o_history_data[0], o_history_data[1]
    s_history_valid, s_history_valid_t = (s_history_valid_data[0],
                                          s_history_valid_data[1])
    o_history_valid, o_history_valid_t = (o_history_valid_data[0],
                                          o_history_valid_data[1])

    total_data = torch.from_numpy(total_data)
    if use_cuda:
        total_data = total_data.cuda()

    epoch = 0
    best_mrr = 0
    while True:
        print('training starting')
        model.train()
        if epoch == args.max_epochs:
            break
        epoch += 1
        loss_epoch = 0
        t0 = time.time()
        print('training time captured')
        # Shuffle the samples together with all four history lists.
        shuffled = shuffle(train_data, s_history, s_history_t, o_history,
                           o_history_t)
        (train_data_shuffle, s_history_shuffle, s_history_t_shuffle,
         o_history_shuffle, o_history_t_shuffle) = shuffled
        print('training data formatted')

        batches = utils.make_batch2(train_data_shuffle, s_history_shuffle,
                                    s_history_t_shuffle, o_history_shuffle,
                                    o_history_t_shuffle, args.batch_size)
        for batch_data, s_hist, s_hist_t, o_hist, o_hist_t in batches:
            batch_data = torch.from_numpy(batch_data).long()
            if use_cuda:
                batch_data = batch_data.cuda()
            print('batch instance preprocessing')
            s_pair = (s_hist, s_hist_t)
            o_pair = (o_hist, o_hist_t)
            # One forward pass per prediction direction; the two losses are
            # summed before the backward pass.
            loss_s = model(batch_data, s_pair, o_pair, graph_dict,
                           subject=True)
            loss_o = model(batch_data, s_pair, o_pair, graph_dict,
                           subject=False)
            loss = loss_s + loss_o
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_data) / args.batch_size), t3 - t0))

        # Validation only on scheduled epochs in the second half of training.
        if (epoch % args.valid_every != 0
                or epoch < int(args.max_epochs / 2)):
            continue

        model.eval()
        global_model.eval()
        total_loss = 0
        total_ranks = np.array([])
        model.init_history(train_data, (s_history, s_history_t),
                           (o_history, o_history_t), valid_data,
                           (s_history_valid, s_history_valid_t),
                           (o_history_valid, o_history_valid_t), test_data,
                           (s_history_test, s_history_test_t),
                           (o_history_test, o_history_test_t))
        model.latest_time = valid_data[0][3]

        with torch.no_grad():
            for idx, batch_data in enumerate(valid_data):
                s_pair = (s_history_valid[idx], s_history_valid_t[idx])
                o_pair = (o_history_valid[idx], o_history_valid_t[idx])

                if use_cuda:
                    batch_data = batch_data.cuda()

                ranks, loss = model.evaluate_filter(batch_data, s_pair,
                                                    o_pair, global_model,
                                                    total_data)
                total_ranks = np.concatenate((total_ranks, ranks))
                total_loss += loss.item()

        mrr = np.mean(1.0 / total_ranks)
        mr = np.mean(total_ranks)
        for hit in [1, 3, 10]:
            avg_count = np.mean((total_ranks <= hit))
            print("valid Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
        print("valid MRR (filtered): {:.6f}".format(mrr))
        print("valid MR (filtered): {:.6f}".format(mr))
        print("valid Loss: {:.6f}".format(total_loss / (len(valid_data))))

        if mrr > best_mrr:
            best_mrr = mrr
            # History caches built during validation; stored in both
            # checkpoints under the same keys.
            history_state = {
                's_hist': model.s_hist_test,
                's_cache': model.s_his_cache,
                'o_hist': model.o_hist_test,
                'o_cache': model.o_his_cache,
                's_hist_t': model.s_hist_test_t,
                's_cache_t': model.s_his_cache_t,
                'o_hist_t': model.o_hist_test_t,
                'o_cache_t': model.o_his_cache_t,
                'latest_time': model.latest_time,
            }
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    **history_state,
                    'global_emb': model.global_emb,
                }, model_state_file)
            torch.save(
                {
                    'state_dict': global_model.state_dict(),
                    'epoch': epoch,
                    **history_state,
                    'global_emb': global_model.global_emb,
                }, model_state_global_file2)
            with open(model_graph_file, 'wb') as fp:
                pickle.dump(model.graph_dict, fp)

    print("training done")
Ejemplo n.º 5
0
recall_list = []
f1_list = []
f2_list = []
hloss_list = []

iterations = 0
while iterations < args.runs:
    iterations += 1
    print(
        '****************** iterations ',
        iterations,
    )

    if iterations == 1:
        print("loading data...")
        num_nodes, num_rels = utils.get_total_number(args.dp + args.dataset,
                                                     'stat.txt')

        with open('{}{}/100.w_emb'.format(args.dp, args.dataset), 'rb') as f:
            word_embeds = pickle.load(f, encoding="latin1")
        word_embeds = torch.FloatTensor(word_embeds)
        vocab_size = word_embeds.size(0)

        train_dataset_loader = DistData(args.dp,
                                        args.dataset,
                                        num_nodes,
                                        num_rels,
                                        set_name='train')
        valid_dataset_loader = DistData(args.dp,
                                        args.dataset,
                                        num_nodes,
                                        num_rels,
Ejemplo n.º 6
0
def test(args):
    """Evaluate the best saved RENet checkpoint on the test split.

    Restores the weights and history caches from
    ``models/<dataset><suffix>`` (the suffix depends on ``args.model``),
    ranks every test quadruple one at a time with ``model.evaluate_filter``
    against the combined train/valid/test data, and prints filtered
    Hits@{1,3,10}, MRR and MR.

    Args:
        args: Parsed options. Reads ``dataset``, ``gpu``, ``model``,
            ``n_hidden`` and ``seq_len``.

    Raises:
        ValueError: If ``args.model`` is not 0, 1 or 2. Previously such a
            value surfaced as an unbound ``model_state_file`` when the
            checkpoint was about to be loaded.
    """
    # Checkpoint file suffix per aggregator id. There is deliberately no
    # path separator between the dataset name and the suffix, so the path
    # matches the one the checkpoint was saved under.
    checkpoint_suffixes = {0: 'attn.pth', 1: 'mean.pth', 2: 'gcn.pth'}
    if args.model not in checkpoint_suffixes:
        raise ValueError(
            'unsupported args.model {!r}; expected one of {}'.format(
                args.model, sorted(checkpoint_suffixes)))

    # load data
    data_dir = './data/' + args.dataset
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    test_data, _ = utils.load_quadruples(data_dir, 'test.txt')
    total_data, _ = utils.load_quadruples(data_dir, 'train.txt', 'valid.txt',
                                          'test.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed_all(999)

    os.makedirs('models', exist_ok=True)
    model_state_file = ('models/' + args.dataset +
                        checkpoint_suffixes[args.model])

    model = RENet(num_nodes,
                  args.n_hidden,
                  num_rels,
                  model=args.model,
                  seq_len=args.seq_len)

    if use_cuda:
        model.cuda()

    # History files are pickles despite the .txt extension.
    with open(data_dir + '/test_history_sub.txt', 'rb') as f:
        s_history_test = pickle.load(f)
    with open(data_dir + '/test_history_ob.txt', 'rb') as f:
        o_history_test = pickle.load(f)

    print("\nstart testing:")

    # Always deserialize onto the CPU.
    checkpoint = torch.load(model_state_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.s_hist_test = checkpoint['s_hist']
    model.s_his_cache = checkpoint['s_cache']
    model.o_hist_test = checkpoint['o_hist']
    model.o_his_cache = checkpoint['o_cache']
    model.latest_time = checkpoint['latest_time']

    print("Using best epoch: {}".format(checkpoint['epoch']))

    total_data = torch.from_numpy(total_data)
    test_data = torch.from_numpy(test_data)

    model.eval()

    if use_cuda:
        total_data = total_data.cuda()

    rank_chunks = []
    with torch.no_grad():
        for i in range(len(test_data)):
            batch_data = test_data[i]
            if use_cuda:
                batch_data = batch_data.cuda()

            # Filtered metric
            ranks_filter, _ = model.evaluate_filter(batch_data,
                                                    s_history_test[i],
                                                    o_history_test[i],
                                                    total_data)
            rank_chunks.append(ranks_filter)

    # The ranks were previously grouped per timestamp and then concatenated
    # again in the same order; one concatenation gives the same array. The
    # leading empty float array keeps the dtype and empty case the same.
    total_ranks = np.concatenate([np.array([])] + rank_chunks)
    mrr = np.mean(1.0 / total_ranks)
    mr = np.mean(total_ranks)

    for hit in [1, 3, 10]:
        avg_count = np.mean((total_ranks <= hit))
        print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count))
    print("MRR (filtered): {:.6f}".format(mrr))
    print("MR (filtered): {:.6f}".format(mr))
Ejemplo n.º 7
0
def train(args):
    """Train ``RENet_global`` and save the checkpoint with the lowest epoch loss.

    Args:
        args: Parsed command-line namespace. The fields read here are
            ``dataset``, ``gpu``, ``model``, ``maxpool``, ``n_hidden``,
            ``dropout``, ``seq_len``, ``num_k``, ``lr``, ``batch_size``,
            ``grad_norm`` and ``max_epochs``.
    """
    # load data
    data_dir = './data/' + args.dataset
    num_nodes, num_rels = utils.get_total_number(data_dir, 'stat.txt')
    train_data, train_times_origin = utils.load_quadruples(
        data_dir, 'train.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    os.makedirs('models', exist_ok=True)
    os.makedirs('models/' + args.dataset, exist_ok=True)
    # NOTE(review): the paths for models 0-2 have no '/' between the dataset
    # name and the file name, so they land next to (not inside) the dataset
    # directory created above. Left unchanged because other scripts may load
    # from exactly these paths -- confirm before fixing.
    if args.model == 0:
        model_state_file = 'models/' + args.dataset + 'attn.pth'
    elif args.model == 1:
        model_state_file = 'models/' + args.dataset + 'mean.pth'
    elif args.model == 2:
        model_state_file = 'models/' + args.dataset + 'gcn.pth'
    elif args.model == 3:
        model_state_file = 'models/' + args.dataset + '/max' + str(
            args.maxpool) + 'rgcn_global.pth'

    print("start training...")
    model = RENet_global(num_nodes,
                         args.n_hidden,
                         num_rels,
                         dropout=args.dropout,
                         model=args.model,
                         seq_len=args.seq_len,
                         num_k=args.num_k,
                         maxpool=args.maxpool)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    if use_cuda:
        model.cuda()

    # NOTE: pickle executes arbitrary code on load; only use trusted,
    # locally generated graph files here.
    with open('./data/' + args.dataset + '/train_graphs.txt', 'rb') as f:
        graph_dict = pickle.load(f)

    true_prob_s, true_prob_o = utils.get_true_distribution(
        train_data, num_nodes)

    epoch = 0
    # Start from +inf so the first finished epoch is always checkpointed,
    # whatever the magnitude of the summed loss.
    loss_small = float('inf')
    while True:
        model.train()
        if epoch == args.max_epochs:
            break
        epoch += 1
        loss_epoch = 0
        t0 = time.time()

        # Shuffle into per-epoch names. Rebinding true_prob_s/true_prob_o to
        # the shuffled result would leave them permuted relative to
        # train_times_origin on the next epoch, mis-pairing time steps with
        # their target distributions.
        # (Assumes `shuffle` returns consistently permuted sequences, as
        # sklearn.utils.shuffle does -- its import is outside this block.)
        train_times, epoch_prob_s, epoch_prob_o = shuffle(
            train_times_origin, true_prob_s, true_prob_o)

        for batch_data, true_s, true_o in utils.make_batch(
                train_times, epoch_prob_s, epoch_prob_o, args.batch_size):

            batch_data = torch.from_numpy(batch_data)
            true_s = torch.from_numpy(true_s)
            true_o = torch.from_numpy(true_o)
            if use_cuda:
                batch_data = batch_data.cuda()
                true_s = true_s.cuda()
                true_o = true_o.cuda()

            loss = model(batch_data, true_s, true_o, graph_dict)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()

        t3 = time.time()
        # Refresh the global embeddings over the time steps in original order.
        model.global_emb = model.get_global_emb(train_times_origin, graph_dict)
        print("Epoch {:04d} | Loss {:.4f} | time {:.4f}".format(
            epoch, loss_epoch / (len(train_times) / args.batch_size), t3 - t0))

        # Keep only the checkpoint of the best (lowest summed loss) epoch.
        if loss_epoch < loss_small:
            loss_small = loss_epoch
            if args.model == 3:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'global_emb': model.global_emb
                    }, model_state_file)
            else:
                # NOTE(review): this branch reads history/cache attributes
                # from the model; whether RENet_global defines them is not
                # visible here -- confirm.
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'epoch': epoch,
                        's_hist': model.s_hist_test,
                        's_cache': model.s_his_cache,
                        'o_hist': model.o_hist_test,
                        'o_cache': model.o_his_cache,
                        'latest_time': model.latest_time
                    }, model_state_file)

    print("training done")
Ejemplo n.º 8
0
def train(args):
    """Train ``DArtNet``, optionally resuming, and checkpoint after every epoch.

    When ``args.retrain`` is non-zero the function first tries to resume from
    ``<model_dir>/checkpoint.pth`` (model + optimizer state). If that file is
    missing it falls back to the highest-numbered ``epoch-N.pth`` file (model
    state only). If neither can be loaded, training starts from scratch.

    Args:
        args: Parsed command-line namespace. The fields read here are
            ``dataset_path``, ``dataset``, ``gpu``, ``dropout``, ``n_hidden``,
            ``gamma``, ``num_k``, ``model``, ``seq_len``, ``lr``, ``retrain``,
            ``max_epochs``, ``batch_size`` and ``grad_norm``.
    """
    # load data
    num_nodes, num_rels, num_att = utils.get_total_number(
        args.dataset_path, 'stat.txt')
    train_data, _ = utils.load_hexaruples(args.dataset_path, 'train.txt')

    # check cuda
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    seed = 999
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.set_device(args.gpu)
    # Single source of truth for where checkpoints are mapped on resume, so
    # resuming also works on CPU-only runs (gpu < 0 or CUDA unavailable).
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")

    model_dir = 'models/' + args.dataset + '/{}-{}-{}-{}'.format(
        args.dropout, args.n_hidden, args.gamma, args.num_k)

    os.makedirs('models', exist_ok=True)
    os.makedirs('models/' + args.dataset, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    print("start training...")
    model = DArtNet(num_nodes,
                    args.n_hidden,
                    num_rels,
                    num_att,
                    dropout=args.dropout,
                    model=args.model,
                    seq_len=args.seq_len,
                    num_k=args.num_k,
                    gamma=args.gamma)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.00001)

    if use_cuda:
        model.cuda()

    def _load_history(file_name):
        # Load one pickled history file from the dataset directory.
        # NOTE: pickle executes arbitrary code on load; only use trusted,
        # locally generated history files here.
        with open(args.dataset_path + file_name, 'rb') as f:
            return pickle.load(f)

    # Subject-side histories.
    entity_s_history_train = _load_history('/train_entity_s_history_data.txt')
    rel_s_history_train = _load_history('/train_rel_s_history_data.txt')
    att_s_history_train = _load_history('/train_att_s_history_data.txt')
    self_att_s_history_train = _load_history(
        '/train_self_att_s_history_data.txt')

    # Object-side histories.
    entity_o_history_train = _load_history('/train_entity_o_history_data.txt')
    rel_o_history_train = _load_history('/train_rel_o_history_data.txt')
    att_o_history_train = _load_history('/train_att_o_history_data.txt')
    self_att_o_history_train = _load_history(
        '/train_self_att_o_history_data.txt')

    epoch = 0

    if args.retrain != 0:
        try:
            checkpoint = torch.load(model_dir + '/checkpoint.pth',
                                    map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            model.latest_time = checkpoint['latest_time']
            model.to(device)
        except FileNotFoundError:
            # No rolling checkpoint: fall back to the newest per-epoch file.
            try:
                # Only consider files named exactly 'epoch-<digits>.pth' so an
                # unrelated .pth file in the directory cannot abort the scan.
                saved_epochs = [
                    int(name[len('epoch-'):-len('.pth')])
                    for name in os.listdir(model_dir)
                    if name.startswith('epoch-') and name.endswith('.pth')
                    and name[len('epoch-'):-len('.pth')].isdigit()
                ]
                # max() raises ValueError on an empty list -> handled below.
                last_epoch = max(saved_epochs)
                checkpoint = torch.load(
                    model_dir + '/epoch-{}.pth'.format(last_epoch),
                    map_location=device)
                model.load_state_dict(checkpoint['state_dict'])
                epoch = checkpoint['epoch']
                model.latest_time = checkpoint['latest_time']
                model.to(device)
            except Exception:
                # Deliberate best effort: any failure to resume means a
                # fresh training run.
                print('no model found')
                print('training from scratch')

    while True:
        model.train()
        if epoch == args.max_epochs:
            break
        epoch += 1
        loss_epoch = 0
        loss_att_sub_epoch = 0
        t0 = time.time()

        # Shuffle the samples and all eight history sequences together so
        # they stay aligned; every name is rebound to its shuffled version.
        (train_data, entity_s_history_train, rel_s_history_train,
         entity_o_history_train, rel_o_history_train, att_s_history_train,
         self_att_s_history_train, att_o_history_train,
         self_att_o_history_train) = shuffle(
             train_data, entity_s_history_train, rel_s_history_train,
             entity_o_history_train, rel_o_history_train, att_s_history_train,
             self_att_s_history_train, att_o_history_train,
             self_att_o_history_train)

        iteration = 0
        for (batch_data, s_hist, rel_s_hist, o_hist, rel_o_hist, att_s_hist,
             self_att_s_hist, att_o_hist,
             self_att_o_hist) in utils.make_batch3(
                 train_data, entity_s_history_train, rel_s_history_train,
                 entity_o_history_train, rel_o_history_train,
                 att_s_history_train, self_att_s_history_train,
                 att_o_history_train, self_att_o_history_train,
                 args.batch_size):
            iteration += 1
            print(f'iteration {iteration}', end='\r')
            batch_data = torch.from_numpy(batch_data)

            if use_cuda:
                batch_data = batch_data.cuda()

            loss, loss_att_sub = model.get_loss(batch_data, s_hist, rel_s_hist,
                                                att_s_hist, self_att_s_hist,
                                                o_hist, rel_o_hist, att_o_hist,
                                                self_att_o_hist)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.grad_norm)  # clip gradients
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch += loss.item()
            loss_att_sub_epoch += loss_att_sub.item()

        t3 = time.time()
        # Fractional number of batches, used to report per-batch averages.
        num_batches = len(train_data) / args.batch_size
        print(
            "Epoch {:04d} | Loss {:.4f} | Loss_att_sub {:.4f} | time {:.4f} ".
            format(epoch, loss_epoch / num_batches,
                   loss_att_sub_epoch / num_batches, t3 - t0))

        # Per-epoch snapshot (model only).
        torch.save(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/epoch-{}.pth'.format(epoch))

        # Rolling checkpoint used for resuming (model + optimizer).
        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'latest_time': model.latest_time,
            }, model_dir + '/checkpoint.pth')

    print("training done")