def eval_mean():
    start = time.time()
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)
    model = nn.DataParallel(model)
    kb_id_token = pd.read_pickle('data/kb_ids.pkl')
    query_id_token = pd.read_pickle('data/train_ids.pkl')
    subject_ids = list(kb_id_token.keys())
    dataset = get_dataset()[70000:]

    val_dataset = dataset
    merge_scores = []
    for index in range(1):

        model.load_state_dict(
            torch.load('data/model' + str(index) + '_to100.pt'))
        model.cuda()
        model.eval()
        kb_preds = []
        kb_loader = tqdm(DataLoaderKb(subject_ids, batch_size=8))
        for batch_index, batch in enumerate(kb_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_out = model(batch).detach()
            kb_preds.append(anchor_out)
        kb_preds = torch.cat(kb_preds, dim=0)

        val_loader = DataLoaderTrain(val_dataset, batch_size=8, shuffle=False)

        tqdm_batch_iterator = tqdm(val_loader)
        val_preds = []
        for batch_index, batch in enumerate(tqdm_batch_iterator):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            anchor_out = model([anchor_ids, anchor_mask]).detach()
            for i in range(len(anchor_out)):
                val_preds.append(anchor_out[i])
        scores = []
        for pred, data in zip(val_preds, val_dataset):
            score = F.pairwise_distance(pred, kb_preds, p=2).cpu().numpy()
            scores.append(score)
        scores = np.array(scores)
        merge_scores.append(scores)
    recall_num = 0
    recall_top100 = {}
    merge_scores = np.mean(merge_scores, axis=0)
    for data, scores in zip(val_dataset, merge_scores):
        text_id = data[0]
        subject_id = data[1]
        indices = scores.argsort()[:100]
        recall_subject_ids = [subject_ids[index] for index in indices]
        if subject_id in recall_subject_ids:
            recall_num += 1
        recall_top100[text_id] = recall_subject_ids
    print(recall_num / len(val_dataset))
    print(recall_num)
    pd.to_pickle(recall_top100, 'data/recall_top100.pkl')
    print(time.time() - start)
def predict_union():
    dataset = get_test()
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)
    model = nn.DataParallel(model)

    def predict_(index):

        model.load_state_dict(
            torch.load('data/model' + str(index) + '_to100.pt'))
        model.cuda()
        model.eval()
        kb_id_token = pd.read_pickle('data/kb_ids.pkl')
        subject_ids = list(kb_id_token.keys())

        kb_preds = []
        kb_loader = tqdm(DataLoaderKb(subject_ids, batch_size=8))
        for batch_index, batch in enumerate(kb_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_out = model(batch).detach()
            kb_preds.append(anchor_out)
        kb_preds = torch.cat(kb_preds, dim=0)
        data_loader = DataLoaderTest(dataset, batch_size=8, shuffle=False)
        recall_top100 = {}
        recall_top100_score = {}
        tqdm_batch_iterator = tqdm(data_loader)
        val_preds = []
        for batch_index, batch in enumerate(tqdm_batch_iterator):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            anchor_out = model([anchor_ids, anchor_mask]).detach()
            for i in range(len(anchor_out)):
                val_preds.append(anchor_out[i])

        recall_num = 0
        for pred, data in zip(val_preds, dataset):
            text_id = data[0]
            subject_id = data[1]
            scores = F.pairwise_distance(pred, kb_preds, p=2).cpu().numpy()
            indices = scores.argsort()[:100]
            recall_subject_ids = [subject_ids[index] for index in indices]
            if subject_id in recall_subject_ids:
                recall_num += 1
            recall_top100_score[text_id] = scores[indices]
            recall_top100[text_id] = recall_subject_ids
        print(recall_num / len(dataset))
        print(recall_num)

        return recall_top100, recall_top100_score

    recall_top100_all = []
    recall_top100_score_all = []
    top_len = 0
    for i in range(1):
        top100, top100_score = predict_(i)
        recall_top100_all.append(top100)
        recall_top100_score_all.append(top100_score)

    recall_num = 0
    recall_top100 = {}
    for data in dataset:
        text_id = data[0]
        subject_id = data[1]
        recall_subject_ids = dict()

        for top100, top100_score in zip(recall_top100_all,
                                        recall_top100_score_all):
            top_subject = top100[text_id]
            top_score = top100_score[text_id]
            for sub, sco in zip(top_subject, top_score):
                if sub in recall_subject_ids:
                    if sco < recall_subject_ids[sub]:
                        recall_subject_ids[sub] = sco
                else:
                    recall_subject_ids[sub] = sco
        recall_subject_ids = sorted(recall_subject_ids.items(),
                                    key=lambda d: d[1],
                                    reverse=False)
        recall_subject_ids = [i[0] for i in recall_subject_ids][:100]
        top_len += len(recall_subject_ids)

        if subject_id in recall_subject_ids:
            recall_num += 1
        recall_top100[text_id] = list(recall_subject_ids)
    print(recall_num / len(dataset))
    print(recall_num)
    print(top_len / len(dataset))
    pd.to_pickle(recall_top100, 'data/test_recall_top100.pkl')
def train(index):
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)

    dataset = get_dataset()
    # train_dataset = dataset[:index * 10000] + dataset[(index + 1) * 10000:]
    # val_dataset = dataset[index * 10000:(index + 1) * 10000]
    train_dataset = dataset[:500]
    val_dataset = dataset[500:1000]

    print(len(dataset))
    print(len(train_dataset))
    print(len(val_dataset))

    train_loader = DataLoaderTrain(train_dataset, batch_size=8)
    val_loader = DataLoaderTrain(val_dataset, batch_size=8)

    device = 'cuda'
    model.to(device)
    bert_model_params = list(map(id, model.bert_model.parameters()))
    base_params = filter(lambda p: id(p) not in bert_model_params,
                         model.parameters())
    params_list = [{
        "params": base_params,
        'lr': 5e-04
    }, {
        'params': model.bert_model.parameters(),
        'lr': 3e-05
    }]
    device_ids = range(torch.cuda.device_count())
    print(device_ids)
    model = nn.DataParallel(model)

    optimizer = torch.optim.Adam(
        params_list,
        lr=3e-05)  # torch.optim.Adam(params=model.parameters(), lr=2e-05)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.2,
        patience=1,
        min_lr=5e-6,
    )
    earlystopping = EarlyStopping(model)
    mc = ModelCheckpoint(model, 'data/model' + str(index) + '_to100.pt')
    loss_function = torch.nn.TripletMarginLoss(margin=3)
    # epochs = 15
    epochs = 1
    for i in range(epochs):
        model.train()
        # tqdm_batch_iterator = tqdm(train_loader)
        ave_loss = 0
        pbar = ProgressBar(n_total=len(train_loader),
                           desc=f'training{i}/{epochs}')
        for batch_index, batch in enumerate(train_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            positive_ids = batch[2]
            positive_mask = batch[3]
            negative_ids = batch[4]
            negative_mask = batch[5]
            optimizer.zero_grad()
            anchor_out = model([anchor_ids, anchor_mask])
            positive_out = model([positive_ids, positive_mask])
            negative_out = model([negative_ids, negative_mask])

            loss = loss_function(anchor_out, positive_out, negative_out)
            ave_loss += loss.item() / len(train_loader)
            loss.backward()
            optimizer.step()
            pbar(step=batch_index, info={'loss': loss})
        print(ave_loss)
        model.eval()
        val_loss = 0.0
        preds, labels = [], []
        pos_distance = []
        neg_distance = []
        with torch.no_grad():
            tqdm_batch_iterator = val_loader
            for batch_index, batch in enumerate(tqdm_batch_iterator):
                batch = [i.to('cuda') for i in batch]
                anchor_ids = batch[0]
                anchor_mask = batch[1]
                positive_ids = batch[2]
                positive_mask = batch[3]
                negative_ids = batch[4]
                negative_mask = batch[5]
                anchor_out = model([anchor_ids, anchor_mask])
                positive_out = model([positive_ids, positive_mask])
                negative_out = model([negative_ids, negative_mask])
                loss = loss_function(anchor_out, positive_out, negative_out)
                val_loss += loss.item() / len(val_loader)
                dist_pos = F.pairwise_distance(anchor_out, positive_out,
                                               p=2).detach().cpu().numpy()
                pos_distance.extend(dist_pos)
                dist_neg = F.pairwise_distance(anchor_out, negative_out,
                                               p=2).detach().cpu().numpy()
                neg_distance.extend(dist_neg)
        scheduler.step(val_loss)
        earlystopping.step(val_loss)
        mc.epoch_step(val_loss)
        print('val loss', val_loss)
    torch.save(model.state_dict(), 'data/model' + str(index) + '_to100.pt')
Exemple #4
0
def predict():
    start = time.time()
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)
    model = nn.DataParallel(model)
    kb_id_token = pd.read_pickle('data/kb_ids.pkl')
    subject_ids = list(kb_id_token.keys())
    dataset = get_test()

    merge_scores = []
    recall_top10 = pd.read_pickle('data/test_recall_top10.pkl')
    for index in range(1):

        model.load_state_dict(
            torch.load('data/model' + str(index) + '_to10.pt'))
        model.cuda()
        model.eval()
        kb_preds = []
        kb_loader = tqdm(DataLoaderKb(subject_ids, batch_size=8))
        for batch_index, batch in enumerate(kb_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_out = model(batch).detach()
            kb_preds.append(anchor_out)
        kb_preds = torch.cat(kb_preds, dim=0)

        kb_preds_dict = dict(zip(subject_ids, kb_preds))

        val_loader = DataLoaderTest(dataset, batch_size=4, shuffle=False)

        tqdm_batch_iterator = tqdm(val_loader)
        val_preds = []
        for batch_index, batch in enumerate(tqdm_batch_iterator):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            anchor_out = model([anchor_ids, anchor_mask]).detach()
            for i in range(len(anchor_out)):
                val_preds.append(anchor_out[i])

        scores = []
        for pred, data in zip(val_preds, dataset):
            query_id = data[0]
            subject_id = data[1]
            recall_10 = np.array(recall_top10[query_id])
            recall_10_pred = [
                kb_preds_dict[recall_id].view(1, pred.size(0))
                for recall_id in recall_10
            ]
            recall_10_pred = torch.cat(recall_10_pred, dim=0)
            score = F.pairwise_distance(pred, recall_10_pred,
                                        p=2).cpu().numpy()
            scores.append(score)
        scores = np.array(scores)
        merge_scores.append(scores)
    recall_num = 0
    recall_top1 = {}
    merge_scores = np.mean(merge_scores, axis=0)
    for data, scores in zip(dataset, merge_scores):
        text_id = data[0]
        recall_10 = np.array(recall_top10[text_id])
        indices = scores.argsort()[:1]
        recall_subject_ids = [recall_10[index] for index in indices]
        recall_top1[text_id] = recall_subject_ids
    print(recall_num / len(dataset))
    print(recall_num)
    pd.to_pickle(recall_top1, 'data/test_recall_top1.pkl')
    print(time.time() - start)
def predict_union():
    dataset = get_test()
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)
    model = nn.DataParallel(model)

    def predict_(index):

        model.load_state_dict(
            torch.load('data/model' + str(index) + '_to10.pt'))
        model.cuda()
        model.eval()
        kb_id_token = pd.read_pickle('data/kb_ids.pkl')
        subject_ids = list(kb_id_token.keys())

        kb_preds = []
        kb_loader = tqdm(DataLoaderKb(subject_ids, batch_size=48))
        for batch_index, batch in enumerate(kb_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_out = model(batch).detach()
            kb_preds.append(anchor_out)
        kb_preds = torch.cat(kb_preds, dim=0)

        kb_preds_dict = dict(zip(subject_ids, kb_preds))
        recall_top100 = pd.read_pickle('data/test_recall_top100.pkl')
        data_loader = DataLoaderTest(dataset, batch_size=64, shuffle=False)

        tqdm_batch_iterator = tqdm(data_loader)
        val_preds = []
        for batch_index, batch in enumerate(tqdm_batch_iterator):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            anchor_out = model([anchor_ids, anchor_mask]).detach()
            for i in range(len(anchor_out)):
                val_preds.append(anchor_out[i])
        recall_num = 0
        recall_top10 = {}
        recall_top10_score = {}
        for pred, data in zip(val_preds, dataset):
            query_id = data[0]
            subject_id = data[1]
            recall_100 = np.array(recall_top100[query_id])
            recall_100_pred = [
                kb_preds_dict[recall_id].view(1, 312)
                for recall_id in recall_100
            ]
            recall_100_pred = torch.cat(recall_100_pred, dim=0)
            scores = F.pairwise_distance(pred, recall_100_pred,
                                         p=2).cpu().numpy()
            indices = scores.argsort()[:10]
            recall_subject_ids = [recall_100[index] for index in indices]
            if subject_id in recall_subject_ids:
                recall_num += 1
            recall_top10[query_id] = recall_subject_ids
            recall_top10_score[query_id] = scores[indices]
        print(recall_num / len(dataset))
        print(recall_num)
        return recall_top10, recall_top10_score

    recall_top10_all = []
    top_len = 0
    for i in range(8):
        recall_top10_all.append(predict_(i))
    recall_num = 0
    recall_top10 = {}
    for data in dataset:
        text_id = data[0]
        subject_id = data[1]
        recall_subject_ids = set()
        for top10 in recall_top10_all:
            recall_subject_ids = set(top10[text_id]) | recall_subject_ids
        top_len += len(recall_subject_ids)
        if subject_id in recall_subject_ids:
            recall_num += 1
        recall_top10[text_id] = list(recall_subject_ids)
    pd.to_pickle(recall_top10, 'data/test_recall_top10.pkl')
    print(recall_num / len(dataset))
    print(recall_num)
    print(top_len / len(dataset))
def eval_union():
    dataset = get_dataset()
    val_dataset = dataset
    start = time.time()
    model = BERTModel(bert_pre_model=BERT_TINY_MODEL)
    model = nn.DataParallel(model)

    def eval_(index):

        model.load_state_dict(
            torch.load('data/model' + str(index) + '_to10.pt'))
        model.cuda()
        model.eval()
        kb_id_token = pd.read_pickle('data/kb_ids.pkl')
        subject_ids = list(kb_id_token.keys())

        kb_preds = []
        kb_loader = tqdm(DataLoaderKb(subject_ids, batch_size=8))
        for batch_index, batch in enumerate(kb_loader):
            batch = [i.to('cuda') for i in batch]
            anchor_out = model(batch).detach()
            kb_preds.append(anchor_out)
        kb_preds = torch.cat(kb_preds, dim=0)

        kb_preds_dict = dict(zip(subject_ids, kb_preds))
        recall_top100 = pd.read_pickle('data/recall_top100.pkl')
        val_loader = DataLoaderTrain(val_dataset, batch_size=4, shuffle=False)

        tqdm_batch_iterator = tqdm(val_loader)
        val_preds = []
        for batch_index, batch in enumerate(tqdm_batch_iterator):
            batch = [i.to('cuda') for i in batch]
            anchor_ids = batch[0]
            anchor_mask = batch[1]
            anchor_out = model([anchor_ids, anchor_mask]).detach()
            for i in range(len(anchor_out)):
                val_preds.append(anchor_out[i])
        recall_num = 0
        recall_top10 = {}
        recall_top10_score = {}
        for pred, data in zip(val_preds, val_dataset):
            query_id = data[0]
            subject_id = data[1]
            recall_100 = np.array(recall_top100[query_id])
            recall_100_pred = [
                kb_preds_dict[recall_id] for recall_id in recall_100
            ]
            recall_100_pred = torch.cat(recall_100_pred, dim=0)
            scores = F.pairwise_distance(pred,
                                         recall_100_pred.reshape(-1, 768),
                                         p=2).cpu().numpy()
            indices = scores.argsort()[:10]
            recall_subject_ids = [recall_100[index] for index in indices]
            if subject_id in recall_subject_ids:
                recall_num += 1
            recall_top10[query_id] = recall_subject_ids
            recall_top10_score[query_id] = scores[indices]
        print(recall_num / len(val_dataset))
        print(recall_num)
        return recall_top10, recall_top10_score

    # recall_top10_all = []
    # recall_top10_score_all = []
    # top_len = 0
    # for i in range(8):
    #     top10, top10_score = eval_(i)
    #     recall_top10_all.append(top10)
    #     recall_top10_score_all.append(top100_score)
    #
    # recall_num = 0
    # recall_top100 = {}
    # for data in dataset:
    #     text_id = data[0]
    #     subject_id = data[1]
    #     # recall_subject_ids = []
    #     recall_subject_ids = dict()
    #
    #     for top100, top100_score in zip(recall_top100_all, recall_top100_score_all):
    #         top_subject = top100[text_id]
    #         top_score = top100_score[text_id]
    #         for sub, sco in zip(top_subject, top_score):
    #             if sub in recall_subject_ids:
    #                 if sco < recall_subject_ids[sub]:
    #                     recall_subject_ids[sub] = sco
    #             else:
    #                 recall_subject_ids[sub] = sco
    #     recall_subject_ids = sorted(recall_subject_ids.items(), key=lambda d: d[1], reverse=False)
    #     recall_subject_ids = [i[0] for i in recall_subject_ids][:100]
    #

    recall_top10_all = []
    recall_top10_score_all = []
    top_len = 0
    for i in range(1):
        top10, top10_score = eval_(i)
        recall_top10_all.append(top10)
        recall_top10_score_all.append(top10_score)
    recall_num = 0
    recall_top10 = {}
    for data in val_dataset:
        text_id = data[0]
        subject_id = data[1]

        recall_subject_ids = dict()
        #
        for top100, top100_score in zip(recall_top10_all,
                                        recall_top10_score_all):
            top_subject = top100[text_id]
            top_score = top100_score[text_id]
            for sub, sco in zip(top_subject, top_score):
                if sub in recall_subject_ids:
                    if sco < recall_subject_ids[sub]:
                        recall_subject_ids[sub] = sco
                else:
                    recall_subject_ids[sub] = sco
        recall_subject_ids = sorted(recall_subject_ids.items(),
                                    key=lambda d: d[1],
                                    reverse=False)
        recall_subject_ids = [i[0] for i in recall_subject_ids][:10]

        top_len += len(recall_subject_ids)

        if subject_id in recall_subject_ids:
            recall_num += 1
        recall_top10[text_id] = list(recall_subject_ids)
    pd.to_pickle(recall_top10, 'data/recall_top10.pkl')
    print(recall_num / len(val_dataset))
    print(recall_num)
    print(top_len / len(val_dataset))