示例#1
0
def get_dev_candidate(args, raw_examples, dataset_list, model, verbose=True):
    """Score all corruptions of each triplet and save the selected candidates.

    For every ``(head, rel, tail)`` in ``raw_examples`` two candidate lists
    are scored with ``model.classifier``:

    * head corruption -- the gold head followed by every entity that is not a
      known head of ``(?, rel, tail)`` in ``dataset_list``;
    * tail corruption -- the gold tail followed by every entity that is not a
      known tail of ``(head, rel, ?)`` in ``dataset_list``.

    ``get_triplet_candidate`` picks indices from each score list and the
    corresponding entities are stored under ``"heads_corrupt"`` /
    ``"tails_corrupt"``.

    Args:
        args: configuration; reads ``model_name_or_path``, ``dataset``,
            ``eval_batch_size`` and ``device``.
        raw_examples: sequence of ``(head, rel, tail)`` triplets to process
            (it is iterated twice). Every head/tail is assumed to occur in
            some ``dataset_list[i].raw_examples``; otherwise the embedding
            lookup raises ``KeyError``.
        dataset_list: datasets whose ``raw_examples`` define the entity set and
            the known facts excluded from the corruptions; the last one
            provides feature conversion, special token ids and
            ``max_seq_length``.
        model: provides ``encoder`` and ``classifier``.
        verbose: show the progress bar of the embedding pass.

    Returns:
        dict mapping ``tuple(triplet)`` to ``{"heads_corrupt": [...],
        "tails_corrupt": [...]}``. The same mapping is written with
        ``torch.save`` to ``./data/<args.dataset>/new_dev.dict``.
    """
    logging.info("***** Running Prediction*****")
    model.eval()
    # the last dataset (i.e. test) supplies the tokenisation helpers and ids
    standard_dataset = dataset_list[-1]

    # Collect every entity plus the known (subj, rel) -> objs and
    # (obj, rel) -> subjs facts so true triplets are never used as corruptions.
    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _head, _rel, _tail in _ds.raw_examples:
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    ent_list = sorted(ents)
    # only relations occurring in raw_examples need (entity, relation) embeddings
    rel_list = sorted({_r for _, _r, _ in raw_examples})

    # ========= get all embeddings ==========
    print("get all embeddings")

    save_dir = args.model_name_or_path
    save_path = os.path.join(save_dir, "saved_dev_emb_mat.np")
    new_dev_path = os.path.join("./data", args.dataset, "new_dev.dict")

    if dir_exists(save_dir) and file_exists(save_path):
        print("\tload from file")
        emb_mat = torch.load(save_path)
    else:
        print("\tget all ids")
        input_ids_list, mask_ids_list, segment_ids_list = [], [], []
        for _ent in tqdm(ent_list):
            for _idx_r, _rel in enumerate(rel_list):
                head_ids, rel_ids, tail_ids = standard_dataset.convert_raw_example_to_features(
                    [_ent, _rel, _ent], method="4")
                # strip the first and last id of each segment (presumably the
                # [CLS]/[SEP] added by the converter -- confirm)
                head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[1:-1], tail_ids[1:-1]
                # truncate the entity so "[CLS] ent [SEP] rel [SEP]" fits
                max_ent_len = standard_dataset.max_seq_length - 3 - len(rel_ids)
                head_ids = head_ids[:max_ent_len]
                tail_ids = tail_ids[:max_ent_len]

                # source sequence: [CLS] head [SEP] rel [SEP]
                src_input_ids = ([standard_dataset._cls_id] + head_ids
                                 + [standard_dataset._sep_id] + rel_ids
                                 + [standard_dataset._sep_id])
                src_mask_ids = [1] * len(src_input_ids)
                src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

                if _idx_r == 0:
                    # target sequence "[CLS] tail [SEP]": emitted once per
                    # entity, before that entity's (entity, relation) rows
                    tgt_input_ids = ([standard_dataset._cls_id] + tail_ids
                                     + [standard_dataset._sep_id])
                    input_ids_list.append(tgt_input_ids)
                    mask_ids_list.append([1] * len(tgt_input_ids))
                    segment_ids_list.append([0] * (len(tail_ids) + 2))

                input_ids_list.append(src_input_ids)
                mask_ids_list.append(src_mask_ids)
                segment_ids_list.append(src_segment_ids)

        # # padding to the longest sequence
        max_len = max(len(_e) for _e in input_ids_list)
        assert max_len <= standard_dataset.max_seq_length
        pad_id = standard_dataset._pad_id
        input_ids_list = [_e + [pad_id] * (max_len - len(_e)) for _e in input_ids_list]
        mask_ids_list = [_e + [0] * (max_len - len(_e)) for _e in mask_ids_list]
        segment_ids_list = [_e + [0] * (max_len - len(_e)) for _e in segment_ids_list]
        # # dataset
        enc_dataset = TensorDataset(
            torch.tensor(input_ids_list, dtype=torch.long),
            torch.tensor(mask_ids_list, dtype=torch.long),
            torch.tensor(segment_ids_list, dtype=torch.long),
        )
        enc_dataloader = DataLoader(enc_dataset,
                                    sampler=SequentialSampler(enc_dataset),
                                    batch_size=args.eval_batch_size * 2)
        print("\tget all emb via model")
        embs_list = []
        for batch in tqdm(enc_dataloader,
                          desc="entity embedding",
                          disable=(not verbose)):
            _input_ids, _mask_ids, _segment_ids = (t.to(args.device) for t in batch)
            with torch.no_grad():
                embs = model.encoder(_input_ids,
                                     attention_mask=_mask_ids,
                                     token_type_ids=_segment_ids)
            embs_list.append(embs.detach().cpu())

        emb_mat = torch.cat(embs_list, dim=0).contiguous()
        assert emb_mat.shape[0] == len(input_ids_list)
        # cache emb_mat for later runs
        if dir_exists(save_dir):
            torch.save(emb_mat, save_path)

    # # assign rows: per entity, one entity-only row then one row per relation
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]

    ent_rel2emb = collections.defaultdict(dict)  # ent + rel  (source side, h + r)
    ent2emb = dict()  # ent alone  (target side, t)

    ptr_row = 0
    for _ent in ent_list:
        for _idx_r, _rel in enumerate(rel_list):
            if _idx_r == 0:
                ent2emb[_ent] = emb_mat[ptr_row]
                ptr_row += 1
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ==========  get candidates ==========
    new_dev_dict = collections.defaultdict(dict)
    sim_batch_size = args.eval_batch_size * 8

    for _triplet in tqdm(raw_examples, desc="get candidates"):
        _head, _rel, _tail = _triplet

        # Head corruption: gold head first, then every entity not known to be a
        # head of (?, rel, tail). Filtering the sorted ent_list (instead of a set
        # difference) makes the candidate order independent of the hash seed.
        _pos_head_ents = g_obj2subjs[_tail][_rel]
        head_ent_list = [_head] + [_e for _e in ent_list if _e not in _pos_head_ents]
        tail_ent_list = [_tail] * len(head_ent_list)
        split_idx = len(head_ent_list)

        # Tail corruption: gold tail first, then every entity not known to be a
        # tail of (head, rel, ?).
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _tail_cands = [_tail] + [_e for _e in ent_list if _e not in _pos_tail_ents]
        head_ent_list.extend([_head] * len(_tail_cands))
        tail_ent_list.extend(_tail_cands)

        all_rep_src = torch.stack([ent_rel2emb[_h][_rel] for _h in head_ent_list],
                                  dim=0).to(args.device)
        all_rep_tgt = torch.stack([ent2emb[_t] for _t in tail_ent_list],
                                  dim=0).to(args.device)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            with torch.no_grad():
                logits = model.classifier(all_rep_src[_start:_end],
                                          all_rep_tgt[_start:_end])
                logits = torch.softmax(logits, dim=-1)
            # softmax probability of class index 1 (presumably "true triplet")
            local_scores_list.append(logits.detach().cpu().numpy()[:, 1])
        scores = np.concatenate(local_scores_list, axis=0)

        # left half: head corruption; right half: tail corruption
        heads_corrupt_idx = get_triplet_candidate(scores[:split_idx])
        heads_corrupt = [head_ent_list[i] for i in heads_corrupt_idx]
        tails_corrupt_idx = get_triplet_candidate(scores[split_idx:])
        tails_corrupt = [tail_ent_list[i + split_idx] for i in tails_corrupt_idx]

        new_dev_dict[tuple(_triplet)]["heads_corrupt"] = heads_corrupt
        new_dev_dict[tuple(_triplet)]["tails_corrupt"] = tails_corrupt
    torch.save(new_dev_dict, new_dev_path)
    return new_dev_dict
示例#2
0
def get_emb_mat(save_dir,
                save_path,
                ent_list,
                rel_list,
                dataset,
                args,
                model=None,
                verbose=True):
    """Return the encoder embedding matrix for entities and (entity, relation) pairs.

    The matrix is loaded with ``torch.load`` when both ``save_dir`` and
    ``save_path`` exist. Otherwise it is computed with ``model.encoder`` and,
    if ``save_dir`` exists, cached to ``save_path`` with ``torch.save``.

    Row layout: for each entity in ``ent_list`` order there is one row for the
    entity alone (``[CLS] ent [SEP]``) followed by one row per relation in
    ``rel_list`` order (``[CLS] ent [SEP] rel [SEP]``), giving
    ``len(ent_list) * (1 + len(rel_list))`` rows.

    Args:
        save_dir: directory whose existence enables loading/saving the cache.
        save_path: cache file path.
        ent_list: ordered entities.
        rel_list: ordered relations (must be non-empty: entity-only rows are
            emitted together with the first relation).
        dataset: provides ``convert_raw_example_to_features``,
            ``max_seq_length`` and the ``_cls_id``/``_sep_id``/``_pad_id`` ids.
        args: reads ``eval_batch_size`` and ``device``.
        model: provides ``encoder``; only needed when the cache is missing.
        verbose: show progress bars.

    Returns:
        The embedding matrix (computed rows are moved to CPU before ``torch.cat``).

    Raises:
        ValueError: when the matrix must be computed but ``model`` is None or
            ``ent_list``/``rel_list`` is empty.
    """
    if dir_exists(save_dir) and file_exists(save_path):
        logging.info("load from file")
        return torch.load(save_path)

    # Fail with a clear message instead of an AttributeError on None.encoder
    # or a bare "max() arg is an empty sequence" further down.
    if model is None:
        raise ValueError(
            "get_emb_mat: `model` is required when no cached embedding "
            "matrix exists at {}".format(save_path))
    if len(ent_list) == 0 or len(rel_list) == 0:
        raise ValueError(
            "get_emb_mat: `ent_list` and `rel_list` must both be non-empty")

    logging.info("get all ids")
    input_ids_list, mask_ids_list, segment_ids_list = [], [], []
    for _ent in tqdm(ent_list, disable=(not verbose)):
        for _idx_r, _rel in enumerate(rel_list):
            head_ids, rel_ids, tail_ids = dataset.convert_raw_example_to_features(
                [_ent, _rel, _ent], method="4")
            # strip the first and last id of each segment (presumably the
            # [CLS]/[SEP] added by the converter -- confirm)
            head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[1:-1], tail_ids[1:-1]
            # truncate the entity so "[CLS] ent [SEP] rel [SEP]" fits
            max_ent_len = dataset.max_seq_length - 3 - len(rel_ids)
            head_ids = head_ids[:max_ent_len]
            tail_ids = tail_ids[:max_ent_len]

            # source sequence: [CLS] head [SEP] rel [SEP]
            src_input_ids = ([dataset._cls_id] + head_ids + [dataset._sep_id]
                             + rel_ids + [dataset._sep_id])
            src_mask_ids = [1] * len(src_input_ids)
            src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

            if _idx_r == 0:
                # entity-only sequence "[CLS] tail [SEP]", once per entity and
                # before that entity's (entity, relation) rows
                tgt_input_ids = [dataset._cls_id] + tail_ids + [dataset._sep_id]
                input_ids_list.append(tgt_input_ids)
                mask_ids_list.append([1] * len(tgt_input_ids))
                segment_ids_list.append([0] * (len(tail_ids) + 2))

            input_ids_list.append(src_input_ids)
            mask_ids_list.append(src_mask_ids)
            segment_ids_list.append(src_segment_ids)

    # # padding to the longest sequence
    max_len = max(len(_e) for _e in input_ids_list)
    assert max_len <= dataset.max_seq_length
    input_ids_list = [
        _e + [dataset._pad_id] * (max_len - len(_e)) for _e in input_ids_list
    ]
    mask_ids_list = [_e + [0] * (max_len - len(_e)) for _e in mask_ids_list]
    segment_ids_list = [_e + [0] * (max_len - len(_e)) for _e in segment_ids_list]
    # # dataset
    enc_dataset = TensorDataset(
        torch.tensor(input_ids_list, dtype=torch.long),
        torch.tensor(mask_ids_list, dtype=torch.long),
        torch.tensor(segment_ids_list, dtype=torch.long),
    )
    enc_dataloader = DataLoader(enc_dataset,
                                sampler=SequentialSampler(enc_dataset),
                                batch_size=args.eval_batch_size * 2)
    logging.info("get all emb via model")
    embs_list = []
    for batch in tqdm(enc_dataloader,
                      desc="entity embedding",
                      disable=(not verbose)):
        _input_ids, _mask_ids, _segment_ids = (t.to(args.device) for t in batch)
        with torch.no_grad():
            embs = model.encoder(_input_ids,
                                 attention_mask=_mask_ids,
                                 token_type_ids=_segment_ids)
        embs_list.append(embs.detach().cpu())

    emb_mat = torch.cat(embs_list, dim=0).contiguous()
    assert emb_mat.shape[0] == len(input_ids_list)
    # cache emb_mat for later runs
    if dir_exists(save_dir):
        torch.save(emb_mat, save_path)
    return emb_mat
示例#3
0
def get_scores(args,
               raw_examples,
               dataset,
               model,
               data_type=None,
               verbose=True):
    """Score every triplet's head and tail against all entities of ``dataset``.

    For each ``(head, rel, tail)`` the first ``len(dataset.id2ent_list)``
    scores put every entity in the head slot (gold tail kept); the remaining
    ones put every entity in the tail slot (gold head kept). Results are cached
    under ``<save_dir>/<data_type>_{head,tail}_full_scores.list`` and reloaded
    from there when both files exist.

    Args:
        args: reads ``model_name_or_path`` (used as ``save_dir`` if it is an
            existing directory, else ``output_dir``), ``eval_batch_size`` and
            ``device``.
        raw_examples: iterable of ``(head, rel, tail)`` triplets.
        dataset: provides ``ent2idx``, ``rel2idx``, ``id2ent_list`` and the
            feature helpers used by ``get_emb_mat``.
        model: provides ``encoder`` and ``classifier``.
        data_type: non-empty name prefix for the cached score files; required.
        verbose: show progress bars.

    Returns:
        ``(head_scores, tail_scores)``: lists with one
        ``[np.array(gold_entity_idx), scores]`` entry per example, where
        ``scores`` is aligned with ``dataset.id2ent_list``.

    Raises:
        ValueError: if ``data_type`` is None.
    """
    if data_type is None:
        # data_type names the cache files; None used to fail with a TypeError
        raise ValueError(
            "get_scores requires a `data_type` string; it prefixes the "
            "cached score file names")

    save_dir = (args.model_name_or_path
                if dir_exists(args.model_name_or_path) else args.output_dir)
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    head_scores_path = os.path.join(save_dir,
                                    data_type + "_head_full_scores.list")
    tail_scores_path = os.path.join(save_dir,
                                    data_type + "_tail_full_scores.list")
    if file_exists(head_scores_path) and file_exists(tail_scores_path):
        logging.info("Load head and tail mode scores")
        head_scores = torch.load(head_scores_path)
        tail_scores = torch.load(tail_scores_path)
        return head_scores, tail_scores

    model.eval()
    ent_list = sorted(dataset.ent2idx.keys())
    rel_list = sorted(dataset.rel2idx.keys())
    logging.info("Load all embeddings")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset,
                          args, model, verbose=verbose)
    # assumes the helper returns mappings indexed as ent2emb[ent] and
    # ent_rel2emb[ent][rel], as used below
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    # ========== get ranked logits score ==========
    head_scores = []
    tail_scores = []
    id2ent_list = list(dataset.id2ent_list)
    split_idx = len(id2ent_list)
    sim_batch_size = args.eval_batch_size * 8
    for _head, _rel, _tail in tqdm(raw_examples,
                                   desc="get_" + data_type + "_scores",
                                   disable=(not verbose)):
        # first half: every entity as head; second half: every entity as tail
        head_ent_list = id2ent_list + [_head] * split_idx
        tail_ent_list = [_tail] * split_idx + id2ent_list

        all_rep_src = torch.stack([ent_rel2emb[_h][_rel] for _h in head_ent_list],
                                  dim=0).to(args.device)
        all_rep_tgt = torch.stack([ent2emb[_t] for _t in tail_ent_list],
                                  dim=0).to(args.device)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            with torch.no_grad():
                logits = model.classifier(all_rep_src[_start:_end],
                                          all_rep_tgt[_start:_end])
                logits = torch.softmax(logits, dim=-1)
            # softmax probability of class index 1 (presumably "true triplet")
            local_scores_list.append(logits.detach().cpu().numpy()[:, 1])

        sample_scores = np.concatenate(local_scores_list, axis=0)
        head_scores.append(
            [np.array(dataset.ent2idx[_head]), sample_scores[:split_idx]])
        tail_scores.append(
            [np.array(dataset.ent2idx[_tail]), sample_scores[split_idx:]])

    if dir_exists(save_dir):
        torch.save(head_scores, head_scores_path)
        torch.save(tail_scores, tail_scores_path)
    logging.info("Get scores finished")
    return head_scores, tail_scores
示例#4
0
def predict_NELL(args, raw_examples, dataset_list, model, verbose=True):
    """Evaluate type-constrained tail prediction and report Hits@k / MR / MRR.

    For every ``(head, rel, tail)`` in ``raw_examples`` the gold tail is ranked
    (via ``safe_ranking`` on a score vector whose index 0 is the gold tail)
    against every entity that is not a known tail of ``(head, rel)`` in any
    dataset of ``dataset_list`` and that appears in
    ``standard_dataset.type_dict[rel]["tail"]``. Scores are the negated
    ``model.distance_metric_fn`` when ``args.cls_method == "dis"``, or the
    class-1 softmax of ``model.classifier`` when it is ``"cls"``.

    Args:
        args: reads ``cls_method``, ``model_name_or_path``, ``output_dir``,
            ``eval_batch_size`` and ``device``.
        raw_examples: sequence of ``(head, rel, tail)`` triplets (iterated twice).
        dataset_list: datasets defining the entity set and known facts; the
            last one provides feature conversion, token ids and ``type_dict``.
        model: provides ``encoder`` and ``distance_metric_fn`` or ``classifier``.
        verbose: log final metrics and write ``link_prediction_metrics.txt``.

    Returns:
        ``[[int(left_rank), int(right_rank)], ...]`` built from
        ``zip(ranks_left, ranks_right)``, regardless of ``verbose``.
        NOTE(review): only tail corruption is evaluated and ``ranks_left`` is
        never filled, so this list is always empty -- confirm what callers need.

    Raises:
        ValueError: if ``args.cls_method`` is neither ``"dis"`` nor ``"cls"``.
    """
    logging.info("***** Running Prediction*****")
    # Fail fast, before the expensive embedding pass; an unknown mode used to
    # crash later in np.concatenate on an empty list.
    if args.cls_method not in ("dis", "cls"):
        raise ValueError(
            "Unsupported cls_method: {!r} (expected 'dis' or 'cls')".format(
                args.cls_method))
    model.eval()
    # the last dataset (i.e. test) supplies the tokenisation helpers and ids
    standard_dataset = dataset_list[-1]

    # every entity plus the known (subj, rel) -> objs facts to filter out
    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _head, _rel, _tail in _ds.raw_examples:
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    ent_list = sorted(ents)
    rel_list = sorted({_r for _, _r, _ in raw_examples})
    # type-compatible tails per relation, as sets for O(1) membership tests
    tail_type_sets = {
        _r: set(standard_dataset.type_dict[_r]["tail"]) for _r in rel_list
    }

    # ========= get all embeddings ==========
    print("get all embeddings")

    save_dir = args.model_name_or_path if dir_exists(args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")

    if dir_exists(save_dir) and file_exists(save_path):
        print("\tload from file")
        emb_mat = torch.load(save_path)
    else:
        print("\tget all ids")
        input_ids_list, mask_ids_list, segment_ids_list = [], [], []
        for _ent in tqdm(ent_list):
            for _idx_r, _rel in enumerate(rel_list):
                head_ids, rel_ids, tail_ids = standard_dataset.convert_raw_example_to_features(
                    [_ent, _rel, _ent], method="4")
                # strip the first and last id of each segment (presumably the
                # [CLS]/[SEP] added by the converter -- confirm)
                head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[1:-1], tail_ids[1:-1]
                # truncate the entity so "[CLS] ent [SEP] rel [SEP]" fits
                max_ent_len = standard_dataset.max_seq_length - 3 - len(rel_ids)
                head_ids = head_ids[:max_ent_len]
                tail_ids = tail_ids[:max_ent_len]

                src_input_ids = ([standard_dataset._cls_id] + head_ids
                                 + [standard_dataset._sep_id] + rel_ids
                                 + [standard_dataset._sep_id])
                src_mask_ids = [1] * len(src_input_ids)
                src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

                if _idx_r == 0:
                    # entity-only row, once per entity, before its relation rows
                    tgt_input_ids = ([standard_dataset._cls_id] + tail_ids
                                     + [standard_dataset._sep_id])
                    input_ids_list.append(tgt_input_ids)
                    mask_ids_list.append([1] * len(tgt_input_ids))
                    segment_ids_list.append([0] * (len(tail_ids) + 2))

                input_ids_list.append(src_input_ids)
                mask_ids_list.append(src_mask_ids)
                segment_ids_list.append(src_segment_ids)

        # # padding to the longest sequence
        max_len = max(len(_e) for _e in input_ids_list)
        assert max_len <= standard_dataset.max_seq_length
        pad_id = standard_dataset._pad_id
        input_ids_list = [_e + [pad_id] * (max_len - len(_e)) for _e in input_ids_list]
        mask_ids_list = [_e + [0] * (max_len - len(_e)) for _e in mask_ids_list]
        segment_ids_list = [_e + [0] * (max_len - len(_e)) for _e in segment_ids_list]
        # # dataset
        enc_dataset = TensorDataset(
            torch.tensor(input_ids_list, dtype=torch.long),
            torch.tensor(mask_ids_list, dtype=torch.long),
            torch.tensor(segment_ids_list, dtype=torch.long),
        )
        enc_dataloader = DataLoader(
            enc_dataset, sampler=SequentialSampler(enc_dataset), batch_size=args.eval_batch_size * 2)
        print("\tget all emb via model")
        embs_list = []
        for batch in tqdm(enc_dataloader, desc="entity embedding", disable=(not verbose)):
            _input_ids, _mask_ids, _segment_ids = (t.to(args.device) for t in batch)
            with torch.no_grad():
                embs = model.encoder(_input_ids, attention_mask=_mask_ids, token_type_ids=_segment_ids)
            embs_list.append(embs.detach().cpu())

        emb_mat = torch.cat(embs_list, dim=0).contiguous()
        assert emb_mat.shape[0] == len(input_ids_list)
        # cache emb_mat for later runs
        if dir_exists(save_dir):
            torch.save(emb_mat, save_path)

    # # assign rows: per entity, one entity-only row then one row per relation
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]

    ent_rel2emb = collections.defaultdict(dict)  # ent + rel (source side)
    ent2emb = dict()  # ent alone (target side)

    ptr_row = 0
    for _ent in ent_list:
        for _idx_r, _rel in enumerate(rel_list):
            if _idx_r == 0:
                ent2emb[_ent] = emb_mat[ptr_row]
                ptr_row += 1
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ========= run link prediction ==========
    if args.cls_method == "dis":
        def _score_batch(rep_src, rep_tgt):
            # higher score = better candidate, hence the negated distance
            distances = model.distance_metric_fn(rep_src, rep_tgt).to(torch.float32)
            return (-distances).detach().cpu().numpy()
    else:  # "cls", validated above
        def _score_batch(rep_src, rep_tgt):
            logits = model.classifier(rep_src, rep_tgt).to(torch.float32)
            return torch.softmax(logits, dim=-1).detach().cpu().numpy()[:, 1]

    ranks_left, ranks_right, ranks = [], [], []
    hits_right = [[] for _ in range(10)]  # hits_right[k] holds Hits@(k+1) flags
    top_ten_hit_count = 0
    top_five_hit_count = 0
    top_one_hit_count = 0
    sim_batch_size = args.eval_batch_size * 8
    for _idx_ex, (_head, _rel, _tail) in enumerate(tqdm(raw_examples, desc="evaluating")):
        # Tail corruption only: the gold tail first, then every type-compatible
        # entity that is not a known tail of (head, rel, ?). Filtering the
        # sorted ent_list keeps the order independent of the hash seed.
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _allowed_tails = tail_type_sets[_rel]
        tail_ent_list = [_tail] + [
            _e for _e in ent_list
            if _e not in _pos_tail_ents and _e in _allowed_tails
        ]

        # every candidate shares the same (head, rel) source embedding
        all_rep_src = torch.stack([ent_rel2emb[_head][_rel]] * len(tail_ent_list),
                                  dim=0).to(args.device)
        all_rep_tgt = torch.stack([ent2emb[_t] for _t in tail_ent_list],
                                  dim=0).to(args.device)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            with torch.no_grad():
                local_scores_list.append(
                    _score_batch(all_rep_src[_start:_end], all_rep_tgt[_start:_end]))
        scores = np.concatenate(local_scores_list, axis=0)

        # right (tail) rank of the gold entity at index 0
        right_rank = safe_ranking(scores)
        ranks_right.append(right_rank)
        ranks.append(right_rank)
        # log
        top_ten_hit_count += int(right_rank <= 10)
        top_five_hit_count += int(right_rank <= 5)
        top_one_hit_count += int(right_rank <= 1)
        if (_idx_ex + 1) % 10 == 0:
            logger.info("hit@1 until now: {}".format(top_one_hit_count * 1.0 / len(ranks)))
            logger.info("hit@5 until now: {}".format(top_five_hit_count * 1.0 / len(ranks)))
            logger.info("hit@10 until now: {}".format(top_ten_hit_count * 1.0 / len(ranks)))

        # hits
        for hits_level in range(10):
            hits_right[hits_level].append(1.0 if right_rank <= hits_level + 1 else 0.0)

    if verbose:
        for i in [0, 4, 9]:
            logger.info('Hits right @{0}: {1}'.format(i + 1, np.mean(hits_right[i])))
        logger.info('Mean rank right: {0}'.format(np.mean(ranks_right)))
        logger.info('Mean reciprocal rank right: {0}'.format(np.mean(1. / np.array(ranks_right))))

        with open(os.path.join(args.output_dir, "link_prediction_metrics.txt"), "w", encoding="utf-8") as fp:
            for i in [0, 4, 9]:
                fp.write('Hits right @{0}: {1}\n'.format(i + 1, np.mean(hits_right[i])))
            fp.write('Mean rank right: {0}\n'.format(np.mean(ranks_right)))
            fp.write('Mean reciprocal rank right: {0}\n'.format(np.mean(1. / np.array(ranks_right))))
        print("save finished!")

    # Returned regardless of verbose (previously None when verbose=False).
    tuple_ranks = [[int(_l), int(_r)] for _l, _r in zip(ranks_left, ranks_right)]
    return tuple_ranks