Example #1
def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info('Loading model.')
    with open(args.config, 'r') as f:
        params = json.load(f)
    model = BiEncoderRanker(params)
    model.load_model(args.ckpt)
    model.to(device)

    logger.info('Loading data.')
    dataset, entity_ids = load_dataset(args.input, model.tokenizer,
                                       args.max_seq_length)
    sampler = torch.utils.data.SequentialSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             sampler=sampler)

    with open(args.output, 'w') as g:
        for i, (tokens, entity_id) in enumerate(zip(dataloader, entity_ids)):
            tokens = tokens[0].to(device)
            with torch.no_grad():
                encodings = model.encode_context(tokens)
            serialized = '\t'.join(str(x) for x in encodings[0].tolist())
            line = f'{i}\t{entity_id}\t{serialized}\n'
            g.write(line)
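
A minimal sketch of the CLI wrapper this main(args) appears to assume; the flag names are inferred from the attributes read above, and the imports of torch, json, the logger, BiEncoderRanker, and load_dataset are taken as given:

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='JSON file of model params')
    parser.add_argument('--ckpt', required=True, help='model checkpoint to load')
    parser.add_argument('--input', required=True, help='input dataset path')
    parser.add_argument('--output', required=True, help='TSV file to write encodings to')
    parser.add_argument('--max_seq_length', type=int, default=128)
    main(parser.parse_args())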
Example #2
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device

    cand_encode_path = params.get("cand_encode_path", None)

    # If candidate encodings are not pre-computed, load/generate the
    # candidate pool from which to compute them.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    candidate_encoding = None
    if cand_encode_path is not None:
        # Try to load candidate encodings from disk; on success, skip recomputation
        try:
            logger.info("Loading pre-computed candidate encodings.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception:
            logger.info("Loading failed. Generating candidate encodings.")

    if candidate_encoding is None:
        candidate_encoding = encode_candidate_zeshel(
            reranker,
            candidate_pool,
            params["encode_batch_size"],
            silent=params["silent"],
            logger=logger,
        )

        if cand_encode_path is not None:
            # Save candidate encoding to avoid re-compute
            logger.info("Saving candidate encoding to file " +
                        cand_encode_path)
            # torch.save takes (obj, path); the original had the arguments swapped
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples." % len(test_samples))

    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(test_tensor_data,
                                 sampler=test_sampler,
                                 batch_size=params["encode_batch_size"])

    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_path = os.path.join(
            params['output_path'],
            'candidates_%s_top%d.t7' % (params['mode'], params['top_k']))
        torch.save(new_data, save_data_path)
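
A hedged sketch of the params dict this main(params) expects; the keys are exactly those accessed above, the values are placeholders:

params = {
    "output_path": "output/",
    "cand_encode_path": None,   # optional cache file for candidate encodings
    "cand_pool_path": None,     # optional pre-tokenized candidate pool
    "encode_batch_size": 8,
    "silent": False,
    "mode": "test",
    "data_path": "data/zeshel",
    "max_context_length": 128,
    "max_cand_length": 128,
    "context_key": "context",
    "debug": False,
    "save_topk_result": True,
    "top_k": 64,
    "zeshel": True,
}
main(params)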
Example #3
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device

    cand_encode_path = params.get("cand_encode_path", None)

    # If candidate encodings are not pre-computed, load/generate the
    # candidate pool from which to compute them.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    candidate_encoding = None
    if cand_encode_path is not None:
        # Try to load candidate encodings from disk; on success, skip recomputation
        try:
            logger.info("Loading pre-computed candidate encodings.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception:
            logger.info("Loading failed. Generating candidate encodings.")

    if candidate_encoding is None:
        candidate_encoding = encode_candidate(reranker,
                                              candidate_pool,
                                              params["encode_batch_size"],
                                              silent=params["silent"],
                                              logger=logger,
                                              is_zeshel=params.get(
                                                  "zeshel", None))

        if cand_encode_path is not None:
            # Save candidate encoding to avoid re-compute
            logger.info("Saving candidate encoding to file " +
                        cand_encode_path)
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])

    logger.info("Read %d test samples." % len(test_samples))

    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )

    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(test_tensor_data,
                                 sampler=test_sampler,
                                 batch_size=params["encode_batch_size"])

    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_path = os.path.join(
            params['output_path'],
            'candidates_%s_top%d.t7' % (params['mode'], params['top_k']))
        torch.save(new_data, save_data_path)
Example #4
def get_biencoder(parameters):
    return BiEncoderRanker(parameters)
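
Hypothetical usage; the keys below are placeholders borrowed from the other examples in this file:

parameters = {"path_to_model": "models/epoch_4", "add_linear": False}
reranker = get_biencoder(parameters)
tokenizer, model = reranker.tokenizer, reranker.model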
Example #5
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    cands_path = os.path.join(output_path, 'cands.pickle')

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Check and load stored embedding data
    embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
    embed_data = None
    if os.path.isfile(embed_data_path):
        embed_data = torch.load(embed_data_path)

    if use_types:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            if 'dict_idxs_by_type' in embed_data:
                dict_idxs_by_type = embed_data['dict_idxs_by_type']
            else:
                dict_idxs_by_type = data_process.get_idxs_by_type(
                    test_dictionary)
            dict_indexes = data_process.get_index_from_embeds(
                dict_embeds,
                dict_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
            men_indexes = data_process.get_index_from_embeds(
                men_embeds,
                men_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                encoder_type="candidate",
                n_gpu=n_gpu,
                corpus=test_dictionary,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            men_data = mention_data
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_data = train_mention_data + mention_data
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker,
                vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                corpus=men_data,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
    else:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            dict_index = data_process.get_index_from_embeds(
                dict_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            men_index = data_process.get_index_from_embeds(
                men_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                'candidate',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
            men_embeds, men_index = data_process.embed_and_index(
                reranker,
                vecs,
                'context',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])

    # Save computed embedding data if not loaded from disk
    if embed_data is None:
        embed_data = {}
        embed_data['dict_embeds'] = dict_embeds
        embed_data['men_embeds'] = men_embeds
        if use_types:
            embed_data['dict_idxs_by_type'] = dict_idxs_by_type
            embed_data['men_idxs_by_type'] = men_idxs_by_type
        # NOTE: Cannot pickle faiss index because it is a SwigPyObject
        torch.save(embed_data,
                   embed_data_path,
                   pickle_protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Starting KNN search...")
    # Fetch (k+1) NN mention candidates
    if not use_types:
        _men_embeds = men_embeds
        if params['transductive']:
            _men_embeds = _men_embeds[n_train_mentions:]
        n_mens_to_fetch = len(_men_embeds) if within_doc else knn + 1
        nn_men_dists, nn_men_idxs = men_index.search(_men_embeds,
                                                     n_mens_to_fetch)
    else:
        query_len = len(men_embeds) - (n_train_mentions
                                       if params['transductive'] else 0)
        nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
        nn_men_dists = -1 * np.ones((query_len, query_len), dtype='float64')
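        # Result buffers are (query_len, query_len): wide enough for any per-type
        # neighbor count; slots left unfilled stay -1 and are masked out later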
        for entity_type in men_indexes:
            # Query with this type's mentions (test mentions only, in transductive mode)
            type_idxs = men_idxs_by_type[entity_type]
            if params['transductive']:
                type_idxs = type_idxs[type_idxs >= n_train_mentions]
            men_embeds_by_type = men_embeds[type_idxs]
            n_mens_to_fetch = len(
                men_embeds_by_type) if within_doc else knn + 1
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                entity_type].search(
                    men_embeds_by_type,
                    min(n_mens_to_fetch, len(men_embeds_by_type)))
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x],
                        nn_men_idxs_by_type)))
            i = -1
            for idx in men_idxs_by_type[entity_type]:
                if params['transductive']:
                    idx -= n_train_mentions
                if idx < 0:
                    continue
                i += 1
                nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = nn_men_dists_by_type[i]
    logger.info("Search finished")

    logger.info(f'Calculating mention recall@{knn}')
    # Find the most similar entity and k-nn mentions for each mention query
    men_recall_knn = []
    men_cands = []
    cui_sums = defaultdict(int)
    for m in mention_data:
        cui_sums[m['label_cuis'][0]] += 1
    for idx in range(len(nn_men_idxs)):
        filter_mask_neg1 = nn_men_idxs[idx] != -1
        men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
        men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                test_context_doc_ids[idx],
                test_context_doc_ids,
                return_numpy=True)
            men_cand_scores = men_cand_scores[wd_mask]

        # Filter candidates to remove mention query and keep only the top k candidates
        filter_mask = men_cand_idxs != idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[
            filter_mask][:knn], men_cand_scores[filter_mask][:knn]
        assert len(men_cand_idxs) == knn
        men_cands.append(men_cand_idxs)
        # Calculate mention recall@k
        gold_cui = mention_data[idx]['label_cuis'][0]
        if cui_sums[gold_cui] > 1:
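            # A hit is a retrieved k-NN mention sharing the query's gold CUI; the
            # denominator is the query's retrievable co-mentions (itself excluded), capped at k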
            recall_hit = [
                1. for midx in men_cand_idxs
                if mention_data[midx]['label_cuis'][0] == gold_cui
            ]
            recall_knn = sum(recall_hit) / min(cui_sums[gold_cui] - 1, knn)
            men_recall_knn.append(recall_knn)
    logger.info('Done')

    # Pickle the top-k candidate sets
    print(f"Saving top-{knn} candidates")
    with open(cands_path, 'wb') as write_handle:
        pickle.dump(men_cands, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Output final results
    mean_recall_knn = np.mean(men_recall_knn)
    logger.info(f"Result: Final mean recall@{knn} = {mean_recall_knn}")
Example #6
def main(params):
    # Parameter initializations
    logger = utils.get_logger(params["output_path"])
    global SCORING_BATCH_SIZE
    SCORING_BATCH_SIZE = params["scoring_batch_size"]
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path
    biencoder_indices_path = params["biencoder_indices_path"]
    if biencoder_indices_path is None:
        biencoder_indices_path = output_path
    elif not os.path.exists(biencoder_indices_path):
        os.makedirs(biencoder_indices_path)
    max_k = params["knn"]  # Maximum k-NN graph to build for evaluation
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    discovery_mode = params["discovery"]

    # Bi-encoder model
    biencoder_params = copy.deepcopy(params)
    biencoder_params['add_linear'] = False
    bi_reranker = BiEncoderRanker(biencoder_params)
    bi_tokenizer = bi_reranker.tokenizer
    # Number of biencoder nearest-neighbors to fetch for cross-encoder scoring (default: 64)
    k_biencoder = params["bi_knn"]

    # Cross-encoder model
    params['add_linear'] = True
    params['add_sigmoid'] = True
    cross_reranker = CrossEncoderRanker(params)
    n_gpu = cross_reranker.n_gpu
    cross_reranker.model.eval()

    # Input lengths
    max_seq_length = params["max_seq_length"]
    max_context_length = params["max_context_length"]
    max_cand_length = params["max_cand_length"]

    # Fix random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cross_reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Generate and save the top-k biencoder candidates for cross-encoder training and inference
    if params["save_topk_result"]:
        save_topk_biencoder_cands(bi_reranker,
                                  use_types,
                                  logger,
                                  n_gpu,
                                  params,
                                  bi_tokenizer,
                                  max_context_length,
                                  max_cand_length,
                                  pickle_src_path,
                                  topk=64)
        exit()

    data_split = params["data_split"]
    entity_dictionary, tensor_data, processed_data = load_data(
        data_split, bi_tokenizer, max_context_length, max_cand_length, max_k,
        pickle_src_path, params, logger)
    n_entities = len(entity_dictionary)
    n_mentions = len(processed_data)
    # Store dictionary vectors
    dict_vecs = torch.tensor(list(map(lambda x: x['ids'], entity_dictionary)),
                             dtype=torch.long)
    # Store query vectors
    men_vecs = tensor_data[:][0]

    discovery_entities = []
    if discovery_mode:
        discovery_entities, n_ents_dropped, n_mentions_wo_gold_ents, mention_gold_entities = get_entity_idxs_to_drop(
            processed_data, params, logger)

    context_doc_ids = None
    if within_doc:
        # Get context_document_ids for each mention in training and validation
        context_doc_ids = get_context_doc_ids(data_split, params)
    # Needed to call get_biencoder_nns() correctly
    params["only_evaluate"] = True
    _, biencoder_nns = get_biencoder_nns(
        bi_reranker=bi_reranker,
        biencoder_indices_path=biencoder_indices_path,
        entity_dictionary=entity_dictionary,
        entity_dict_vecs=dict_vecs,
        train_men_vecs=None,
        train_processed_data=None,
        train_gold_clusters=None,
        valid_men_vecs=men_vecs,
        valid_processed_data=processed_data,
        use_types=use_types,
        logger=logger,
        n_gpu=n_gpu,
        params=params,
        train_context_doc_ids=None,
        valid_context_doc_ids=context_doc_ids)
    bi_men_idxs = biencoder_nns['men_nns'][:, :k_biencoder]
    bi_ent_idxs = biencoder_nns['dict_nns'][:, :k_biencoder]
    bi_nn_count = np.sum(biencoder_nns['men_nns'] != -1, axis=1)

    # Compute and store the concatenated cross-encoder inputs for validation
    men_concat_inputs, ent_concat_inputs = build_cross_concat_input(
        biencoder_nns, men_vecs, dict_vecs, max_seq_length, k_biencoder)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(max_k, 2)) + 1)]
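    # e.g., max_k = 16 -> knn_vals = [0, 1, 2, 4, 8, 16]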
    max_k = knn_vals[-1]  # Store the maximum evaluation k
    bi_recall = None

    time_start = time.time()
    # Check if k-NN graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        joint_graphs = {}
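        # Each graph is a sparse adjacency matrix in COO form (rows, cols, data)
        # over entity + mention nodes, keyed by k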
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }

        # Score biencoder NNs using cross-encoder
        score_path = os.path.join(output_path, 'cross_scores_indexes.pickle')
        if os.path.isfile(score_path):
            print("Loading stored cross-encoder scores and indexes...")
            with open(score_path, 'rb') as read_handle:
                score_data = pickle.load(read_handle)
            cross_men_topk_idxs = score_data['cross_men_topk_idxs']
            cross_men_topk_scores = score_data['cross_men_topk_scores']
            cross_ent_top1_idx = score_data['cross_ent_top1_idx']
            cross_ent_top1_score = score_data['cross_ent_top1_score']
        else:
            with torch.no_grad():
                logger.info(
                    'Eval: Scoring mention-mention edges using cross-encoder...'
                )
                cross_men_scores = score_in_batches(
                    cross_reranker,
                    max_context_length,
                    men_concat_inputs,
                    is_context_encoder=True,
                    scoring_batch_size=SCORING_BATCH_SIZE)
                for i in range(len(cross_men_scores)):
                    # Set scores for all invalid nearest neighbours to -infinity (due to variable NN counts of mentions)
                    cross_men_scores[i][bi_nn_count[i]:] = float('-inf')
                cross_men_topk_scores, cross_men_topk_idxs = torch.sort(
                    cross_men_scores, dim=1, descending=True)
                cross_men_topk_idxs = cross_men_topk_idxs.cpu()[:, :max_k]
                cross_men_topk_scores = cross_men_topk_scores.cpu()[:, :max_k]
                logger.info('Eval: Scoring done')

                logger.info(
                    'Eval: Scoring mention-entity edges using cross-encoder...'
                )
                cross_ent_scores = score_in_batches(
                    cross_reranker,
                    max_context_length,
                    ent_concat_inputs,
                    is_context_encoder=False,
                    scoring_batch_size=SCORING_BATCH_SIZE)
                cross_ent_top1_score, cross_ent_top1_idx = torch.sort(
                    cross_ent_scores, dim=1, descending=True)
                cross_ent_top1_idx = cross_ent_top1_idx.cpu()
                cross_ent_top1_score = cross_ent_top1_score.cpu()
                if discovery_mode:
                    # Replace the first value in each row with an entity not in the drop set
                    for i in range(cross_ent_top1_idx.shape[0]):
                        for j in range(cross_ent_top1_idx.shape[1]):
                            if cross_ent_top1_idx[i, j] not in discovery_entities:
                                cross_ent_top1_idx[i, 0] = cross_ent_top1_idx[i, j]
                                cross_ent_top1_score[i, 0] = cross_ent_top1_score[i, j]
                                break
                cross_ent_top1_idx = cross_ent_top1_idx[:, 0]
                cross_ent_top1_score = cross_ent_top1_score[:, 0]
                logger.info('Eval: Scoring done')
            # Pickle the scores and nearest indexes
            logger.info("Saving cross-encoder scores and indexes...")
            with open(score_path, 'wb') as write_handle:
                pickle.dump(
                    {
                        'cross_men_topk_idxs': cross_men_topk_idxs,
                        'cross_men_topk_scores': cross_men_topk_scores,
                        'cross_ent_top1_idx': cross_ent_top1_idx,
                        'cross_ent_top1_score': cross_ent_top1_score
                    },
                    write_handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
            logger.info(f"Saved at: {score_path}")

        # Build k-NN graphs
        bi_recall = 0.
        for men_idx in tqdm(range(len(processed_data)),
                            total=len(processed_data),
                            desc="Eval: Building graphs"):
            # Track biencoder recall@<k_biencoder>
            gold_idx = processed_data[men_idx]["label_idxs"][0]
            if gold_idx in bi_ent_idxs[men_idx]:
                bi_recall += 1.
            # Get nearest entity
            m_e_idx = bi_ent_idxs[men_idx, cross_ent_top1_idx[men_idx]]
            m_e_score = cross_ent_top1_score[men_idx]
            if bi_nn_count[men_idx] > 0:
                # Get nearest mentions
                topk_defined_nn_idxs = cross_men_topk_idxs[
                    men_idx][:bi_nn_count[men_idx]]
                m_m_idxs = bi_men_idxs[
                    men_idx,
                    topk_defined_nn_idxs] + n_entities  # Mentions added at an offset of maximum entities
                m_m_scores = cross_men_topk_scores[
                    men_idx][:bi_nn_count[men_idx]]
            # Add edges to the graphs
            for k in joint_graphs:
                # Add mention-entity edge
                joint_graphs[k]['rows'] = np.append(
                    joint_graphs[k]['rows'],
                    [n_entities + men_idx
                     ])  # Mentions added at an offset of maximum entities
                joint_graphs[k]['cols'] = np.append(joint_graphs[k]['cols'],
                                                    m_e_idx)
                joint_graphs[k]['data'] = np.append(joint_graphs[k]['data'],
                                                    m_e_score)
                if k > 0 and bi_nn_count[men_idx] > 0:
                    # Add mention-mention edges
                    joint_graphs[k]['rows'] = np.append(
                        joint_graphs[k]['rows'],
                        [n_entities + men_idx] * len(m_m_idxs[:k]))
                    joint_graphs[k]['cols'] = np.append(
                        joint_graphs[k]['cols'], m_m_idxs[:k])
                    joint_graphs[k]['data'] = np.append(
                        joint_graphs[k]['data'], m_m_scores[:k])
        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved graphs at: {graph_path}")
        # Compute biencoder recall
        bi_recall /= len(processed_data)
        if params['only_recall']:
            logger.info(
                f"Eval: Biencoder recall@{k_biencoder} = {bi_recall * 100}%")
            exit()

    graph_mode = params.get('graph_mode', None)

    if discovery_mode:
        # Run the entity discovery experiment
        results = {
            'data_split': data_split.upper(),
            'n_entities': n_entities,
            'n_mentions': n_mentions,
            'n_entities_dropped':
            f"{n_ents_dropped} ({params['ent_drop_prop'] * 100}%)",
            'n_mentions_wo_gold_entities': n_mentions_wo_gold_ents
        }
        run_discovery_experiment(joint_graphs, discovery_entities,
                                 mention_gold_entities, n_entities, n_mentions,
                                 data_split, results, output_path, time_start,
                                 params, logger)
    else:
        # Run entity linking inference
        result_overview = {'n_entities': n_entities, 'n_mentions': n_mentions}
        results = {}
        if graph_mode is None or graph_mode not in ['directed', 'undirected']:
            results['directed'], results['undirected'] = [], []
        else:
            results[graph_mode] = []
        run_inference(entity_dictionary, processed_data, results,
                      result_overview, joint_graphs, n_entities, time_start,
                      output_path, bi_recall, k_biencoder, logger)
Example #7
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]  # Use as the max-knn value for the graph construction
    use_types = params["use_types"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        # Filter samples without gold entities
        test_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Reduce the entity dictionary to only the gold entities of the mention queries,
    # then combine entities and mentions into one structure for joint embedding and indexing
    new_ents = {}
    new_ents_arr = []
    men_labels = []
    for men in mention_data:
        ent = men['label_idxs'][0]
        if ent not in new_ents:
            new_ents[ent] = len(new_ents_arr)
            new_ents_arr.append(ent)
        men_labels.append(new_ents[ent])
    ent_labels = list(range(len(new_ents_arr)))
    new_ent_vecs = torch.tensor(
        list(map(lambda x: test_dictionary[x]['ids'], new_ents_arr)))
    new_ent_types = list(
        map(lambda x: {"type": test_dictionary[x]['type']}, new_ents_arr))
    test_men_vecs = test_tensor_data[:][0]

    n_mentions = len(test_tensor_data)
    n_entities = len(new_ent_vecs)
    n_embeds = n_mentions + n_entities
    leaf_labels = np.array(ent_labels + men_labels, dtype=int)
    all_vecs = torch.cat((new_ent_vecs, test_men_vecs))
    all_types = new_ent_types + mention_data  # Array of dicts containing key "type" for selected ents and all mentions

    # Values of k to run the evaluation against
    knn_vals = [25 * 2**i for i in range(int(math.log(knn / 25, 2)) + 1)
                ] if params["exact_knn"] is None else [params["exact_knn"]]
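    # e.g., knn = 200 -> knn_vals = [25, 50, 100, 200]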
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()
    knn_fetch_time = 0.  # Stays 0 when the graphs are loaded from disk below

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_embeds, n_embeds)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)
        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
                if 'idxs_by_type' in embed_data:
                    idxs_by_type = embed_data['idxs_by_type']
                else:
                    idxs_by_type = data_process.get_idxs_by_type(all_types)
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                idxs_by_type = data_process.get_idxs_by_type(all_types)
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['embeds'] = embeds
            if use_types:
                embed_data['idxs_by_type'] = idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        # Build faiss search index
        if params["normalize_embeds"]:
            embeds = normalize(embeds, axis=0)
        logger.info("Building KNN index...")
        if use_types:
            search_indexes = data_process.get_index_from_embeds(
                embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            search_index = data_process.get_index_from_embeds(
                embeds, force_exact_search=True)

        logger.info("Starting KNN search...")
        if not use_types:
            faiss_dists, faiss_idxs = search_index.search(embeds, max_knn + 1)
        else:
            query_len = n_embeds
            faiss_idxs = np.zeros((query_len, max_knn + 1))
            faiss_dists = np.zeros((query_len, max_knn + 1), dtype=float)
            for entity_type in search_indexes:
                embeds_by_type = embeds[idxs_by_type[entity_type]]
                nn_dists_by_type, nn_idxs_by_type = search_indexes[
                    entity_type].search(embeds_by_type, max_knn + 1)
                for i, idx in enumerate(idxs_by_type[entity_type]):
                    faiss_idxs[idx] = nn_idxs_by_type[i]
                    faiss_dists[idx] = nn_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar nodes for each mention and node in the set (minus self)
        for idx in trange(n_embeds):
            # Compute adjacent node edge weight
            if idx != 0:
                adj_idx = idx - 1
                adj_data = embeds[adj_idx] @ embeds[idx]
            nn_idxs = faiss_idxs[idx]
            nn_scores = faiss_dists[idx]
            # Filter candidates to remove mention query and keep only the top k candidates
            filter_mask = nn_idxs != idx
            nn_idxs, nn_scores = nn_idxs[filter_mask][:max_knn], nn_scores[
                filter_mask][:max_knn]
            # Add edges to the graphs
            for k in joint_graphs:
                # Select the graph for this k before adding any edges
                joint_graph = joint_graphs[k]
                # Add edge to adjacent node to force the graph to be connected
                if idx != 0:
                    joint_graph['rows'] = np.append(joint_graph['rows'], adj_idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'], idx)
                    joint_graph['data'] = np.append(joint_graph['data'], adj_data)
                # Add mention-mention edges
                joint_graph['rows'] = np.append(joint_graph['rows'], [idx] * k)
                joint_graph['cols'] = np.append(joint_graph['cols'], nn_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'], nn_scores[:k])

        knn_fetch_time = time.time() - time_start
        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    results = {
        'n_leaves': n_embeds,
        'n_entities': n_entities,
        'n_mentions': n_mentions
    }

    graph_processing_time = time.time()
    n_graphs_processed = 0.
    # Different HAC linkage functions to run the analyses over
    linkage_fns = ["single", "complete", "average"] if params["linkage"] is None \
        else [params["linkage"]]

    for fn in linkage_fns:
        logger.info(f"Linkage function: {fn}")
        purities = []
        fn_result = {}
        for k in joint_graphs:
            graph = hg.UndirectedGraph(n_embeds)
            graph.add_edges(joint_graphs[k]['rows'], joint_graphs[k]['cols'])
            # Negated because Higra expects weights as distances, not similarities
            weights = -joint_graphs[k]['data']
            tree = get_hac_tree(graph, weights, linkage=fn)
            purity = hg.dendrogram_purity(tree, leaf_labels)
            fn_result[f"purity@{k}nn"] = purity
            logger.info(f"purity@{k}nn = {purity}")
            purities.append(purity)
            n_graphs_processed += 1
        fn_result["average"] = round(np.mean(purities), 4)
        logger.info(f"average = {fn_result['average']}")
        results[fn] = fn_result

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results
    output_file_name = os.path.join(
        output_path, f"results_{int(time.time())}")  # Unix-timestamped filename

    logger.info(f"Results: \n {results}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nResults saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
Example #8
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model


    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # With gradient accumulation across `y` batches, an effective batch size of `x`
    # is achieved with a per-step batch size of `z = x / y`
    # (e.g., effective batch 128 with 4 accumulation steps -> 32 per step)
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Load train data
    train_samples = utils.read_dataset("train", params["data_path"])
    logger.info("Read %d train samples." % len(train_samples))

    train_data, train_tensor_data = data.process_mention_data(
        train_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    if params["shuffle"]:
        train_sampler = RandomSampler(train_tensor_data)
    else:
        train_sampler = SequentialSampler(train_tensor_data)

    train_dataloader = DataLoader(train_tensor_data,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)

    # Load eval data
    # TODO: reduce duplicated code here
    valid_samples = utils.read_dataset("valid", params["data_path"])
    logger.info("Read %d valid samples." % len(valid_samples))

    valid_data, valid_tensor_data = data.process_mention_data(
        valid_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(valid_tensor_data,
                                  sampler=valid_sampler,
                                  batch_size=eval_batch_size)

    # evaluate before training
    results = evaluate(
        reranker,
        valid_dataloader,
        params,
        device=device,
        logger=logger,
    )

    number_of_samples_per_dataset = {}

    time_start = time.time()

    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))

    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, False))

    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)

    model.train()

    best_epoch_idx = -1
    best_score = -1

    num_train_epochs = params["num_train_epochs"]
    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        tr_loss = 0
        results = None

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            if params["zeshel"]:
                context_input, candidate_input, _, _ = batch
            else:
                context_input, candidate_input, _ = batch
            loss, _ = reranker(context_input, candidate_input)

            # if n_gpu > 1:
            #     loss = loss.mean() # mean() to average on multi-gpu.

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(
                    reranker,
                    valid_dataloader,
                    params,
                    device=device,
                    logger=logger,
                )
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)

        output_eval_file = os.path.join(epoch_output_folder_path,
                                        "eval_results.txt")
        results = evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )

        if results["normalized_accuracy"] > best_score:
            best_score = results["normalized_accuracy"]
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        results = evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )
Example No. 9
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    directed_graph = params["directed_graph"]
    use_types = params["use_types"]
    data_split = params["data_split"] # Parameter default is "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(filter(lambda sample: (len(sample["labels"]) > 0) if mult_labels else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded
        )
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(
        list(map(lambda x: x['ids'], test_dictionary)), dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]
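    # e.g., knn = 16 gives knn_vals = [0, 1, 2, 4, 8, 16] and max_knn = 16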

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities+n_mentions, n_entities+n_mentions)
            }
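        # 'rows'/'cols'/'data' are COO triplets; they can later be materialized
        # as a sparse matrix, e.g. scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape)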

        if use_types:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(reranker, test_dict_vecs, encoder_type="candidate", n_gpu=n_gpu, corpus=test_dictionary, force_exact_search=True)

            # Verify embeddings against the original pre-computed candidate encodings
            og_embeds = torch.load('models/trained/zeshel_og/eval/data_og/cand_encodes.t7')
            world_to_type = {12: 'forgotten_realms', 13: 'lego', 14: 'star_trek', 15: 'yugioh'}
            for world in og_embeds:
                for i, oge in enumerate(tqdm(og_embeds[world])):
                    dict_embed_idx = dict_idxs_by_type[world_to_type[world]][i]
                    # torch.equal returns a single bool; torch.eq yields an
                    # element-wise tensor that cannot be asserted directly
                    assert torch.equal(oge, torch.tensor(dict_embeds[dict_embed_idx]))
            print('PASS')
            print("Queries: Embedding and building index")
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(reranker, test_men_vecs, encoder_type="context", n_gpu=n_gpu, corpus=mention_data, force_exact_search=True)
        else:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker, test_dict_vecs, 'candidate', n_gpu=n_gpu)
            print("Queries: Embedding and building index")
            men_embeds, men_index = data_process.embed_and_index(
                reranker, test_men_vecs, 'context', n_gpu=n_gpu)

        recall_accuracy = {2**i: 0 for i in range(int(math.log(params['recall_k'], 2)) + 1)}
        recall_idxs = [0.]*params['recall_k']
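        # recall_accuracy tracks recall@k for k in {1, 2, 4, ..., recall_k};
        # recall_idxs histograms the 0-based rank of the first gold hit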

        # Find the most similar entity and k-nn mentions for each mention query
        for men_query_idx, men_embed in enumerate(tqdm(men_embeds, total=len(men_embeds), desc="Fetching k-NN")):
            men_embed = np.expand_dims(men_embed, axis=0)
            
            dict_type_idx_mapping, men_type_idx_mapping = None, None
            if use_types:
                entity_type = mention_data[men_query_idx]['type']
                dict_index = dict_indexes[entity_type]
                men_index = men_indexes[entity_type]
                dict_type_idx_mapping = dict_idxs_by_type[entity_type]
                men_type_idx_mapping = men_idxs_by_type[entity_type]
            
            # Fetch nearest entity candidate
            gold_idxs = mention_data[men_query_idx]["label_idxs"][:mention_data[men_query_idx]["n_labels"]]
            dict_cand_idx, dict_cand_score, recall_idx = get_query_nn(
                1, dict_embeds, dict_index, men_embed, searchK=params['recall_k'], gold_idxs=gold_idxs, type_idx_mapping=dict_type_idx_mapping)
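            # recall_idx is the 0-based rank of the gold entity among the
            # searchK results, or -1 if the gold entity was not retrieved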
            # Compute recall metric
            if recall_idx > -1:
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.

            if not params['only_recall']:
                # Fetch (k+1) NN mention candidates
                men_cand_idxs, men_cand_scores = get_query_nn(
                    max_knn + 1, men_embeds, men_index, men_embed, type_idx_mapping=men_type_idx_mapping)
                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != men_query_idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'], [n_entities+men_query_idx])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(
                        joint_graph['cols'], dict_cand_idx)
                    joint_graph['data'] = np.append(
                        joint_graph['data'], dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'], [n_entities+men_query_idx]*len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'], n_entities+men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode]/np.sum(recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(men_embeds)} queries):
        ---------------""")
        logger.info(f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})")
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(men_embeds)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    results = []
    for k in joint_graphs:
        print(f"\nGraph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = partition_graph(
            joint_graphs[k], n_entities, directed_graph, return_clusters=True)
        # Infer predictions from clusters
        result = analyzeClusters(clusters, test_dictionary, mention_data, k)
        # Store result
        results.append(result)

    # Store results
    # UTC epoch seconds; equivalent to calendar.timegm(time.gmtime())
    output_file_name = os.path.join(
        output_path, f"eval_results_{int(time.time())}")
    result_overview = {
        'n_entities': results[0]['n_entities'],
        'n_mentions': results[0]['n_mentions'],
        'directed': directed_graph
    }

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        logger.info("Recall data not available since graphs were loaded from disk")
    
    for r in results:
        k = r['knn_mentions']
        result_overview[f'accuracy@knn{k}'] = r['accuracy']
        logger.info(f"accuracy@knn{k} = {r['accuracy']}")
        output_file = f'{output_file_name}-{k}.json'
        with open(output_file, 'w') as f:
            json.dump(r, f, indent=2)
            print(f"\nPredictions @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")
Example No. 10
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = model_output_path

    knn = params["knn"]
    use_types = params["use_types"]

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # To achieve an effective batch size of `x` while accumulating gradients
    # across `y` batches, use a per-step batch size of `z = x / y`
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    entity_dictionary_loaded = False
    entity_dictionary_pkl_path = os.path.join(pickle_src_path,
                                              'entity_dictionary.pickle')
    if os.path.isfile(entity_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(entity_dictionary_pkl_path, 'rb') as read_handle:
            entity_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if not params["only_evaluate"]:
        # Load train data
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_processed_data_pkl_path = os.path.join(
            pickle_src_path, 'train_processed_data.pickle')
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_processed_data_pkl_path, 'rb') as read_handle:
                train_processed_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset("train", params["data_path"])
            if not entity_dictionary_loaded:
                with open(
                        os.path.join(params["data_path"], 'dictionary.pickle'),
                        'rb') as read_handle:
                    entity_dictionary = pickle.load(read_handle)

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            if params["filter_unlabeled"]:
                # Filter samples without gold entities
                train_samples = list(
                    filter(
                        lambda sample: (len(sample["labels"]) > 0)
                        if mult_labels else (sample["label"] is not None),
                        train_samples))
            logger.info("Read %d train samples." % len(train_samples))

            # For the discovery experiment: drop from training the entities
            # that were randomly dropped from the dev/test set, so that they
            # remain unseen during training
            if params["drop_entities"]:
                assert entity_dictionary_loaded
                drop_set_pkl_path = os.path.join(
                    pickle_src_path, 'drop_set_mention_data.pickle')
                with open(drop_set_pkl_path, 'rb') as read_handle:
                    drop_set_data = pickle.load(read_handle)
                drop_set_mention_gold_cui_idxs = list(
                    map(lambda x: x['label_idxs'][0], drop_set_data))
                ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
                ent_drop_prop = 0.1
                logger.info(
                    f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set"
                )
                # Get entity indices to drop
                n_ents_dropped = int(ent_drop_prop * len(ents_in_data))
                rng = np.random.default_rng(seed=17)
                dropped_ent_idxs = rng.choice(ents_in_data,
                                              size=n_ents_dropped,
                                              replace=False)

                # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
                keep_mask = np.ones(len(entity_dictionary), dtype='bool')
                keep_mask[dropped_ent_idxs] = False
                entity_dictionary = np.array(entity_dictionary)[keep_mask]
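                # e.g., 1000 unique gold entities in the drop set with
                # ent_drop_prop = 0.1 removes 100 entity rows (hypothetical numbers)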

            train_processed_data, entity_dictionary, train_tensor_data = data_process.process_mention_data(
                train_samples,
                entity_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                context_key=params["context_key"],
                multi_label_key="labels" if mult_labels else None,
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            if not entity_dictionary_loaded:
                with open(entity_dictionary_pkl_path, 'wb') as write_handle:
                    pickle.dump(entity_dictionary,
                                write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_processed_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store the query mention vectors
        train_men_vecs = train_tensor_data[:][0]

        if params["shuffle"]:
            train_sampler = RandomSampler(train_tensor_data)
        else:
            train_sampler = SequentialSampler(train_tensor_data)

        train_dataloader = DataLoader(train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

    # Store the entity dictionary vectors
    entity_dict_vecs = torch.tensor(list(
        map(lambda x: x['ids'], entity_dictionary)),
                                    dtype=torch.long)

    # Load eval data
    valid_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                              'valid_tensor_data.pickle')
    valid_processed_data_pkl_path = os.path.join(
        pickle_src_path, 'valid_processed_data.pickle')
    if os.path.isfile(valid_tensor_data_pkl_path) and os.path.isfile(
            valid_processed_data_pkl_path):
        print("Loading stored processed valid data...")
        with open(valid_tensor_data_pkl_path, 'rb') as read_handle:
            valid_tensor_data = pickle.load(read_handle)
        with open(valid_processed_data_pkl_path, 'rb') as read_handle:
            valid_processed_data = pickle.load(read_handle)
    else:
        valid_samples = utils.read_dataset("valid", params["data_path"])
        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in valid_samples[0].keys()
        # Filter samples without gold entities
        valid_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), valid_samples))
        logger.info("Read %d valid samples." % len(valid_samples))

        valid_processed_data, _, valid_tensor_data = data_process.process_mention_data(
            valid_samples,
            entity_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            context_key=params["context_key"],
            multi_label_key="labels" if mult_labels else None,
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=True)
        print("Saving processed valid data...")
        with open(valid_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(valid_processed_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_processed_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store the query mention vectors
    valid_men_vecs = valid_tensor_data[:][0]

    # valid_sampler = SequentialSampler(valid_tensor_data)
    # valid_dataloader = DataLoader(
    #     valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    # )

    if params["only_evaluate"]:
        evaluate(reranker,
                 entity_dict_vecs,
                 valid_men_vecs,
                 device=device,
                 logger=logger,
                 knn=knn,
                 n_gpu=n_gpu,
                 entity_data=entity_dictionary,
                 query_data=valid_processed_data,
                 silent=params["silent"],
                 use_types=use_types or params["use_types_for_eval"],
                 embed_batch_size=params['embed_batch_size'],
                 force_exact_search=use_types or params["use_types_for_eval"]
                 or params["force_exact_search"],
                 probe_mult_factor=params['probe_mult_factor'])
        exit()

    time_start = time.time()
    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, data_parallel: {}".format(
        device, n_gpu, params["data_parallel"]))

    # Set model to training mode
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)
    best_epoch_idx = -1
    best_score = -1
    num_train_epochs = params["num_train_epochs"]

    init_base_model_run = params.get("path_to_model", None) is None
    init_run_pkl_path = os.path.join(
        pickle_src_path, f'init_run_{"type" if use_types else "notype"}.t7')

    dict_embed_data = None

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        torch.cuda.empty_cache()
        tr_loss = 0
        results = None

        # Check if embeddings and index can be loaded
        init_run_data_loaded = False
        if init_base_model_run:
            if os.path.isfile(init_run_pkl_path):
                logger.info('Loading init run data')
                init_run_data = torch.load(init_run_pkl_path)
                init_run_data_loaded = True
        load_stored_data = init_base_model_run and init_run_data_loaded

        # Compute mention and entity embeddings at the start of each epoch
        if use_types:
            if load_stored_data:
                train_dict_embeddings, dict_idxs_by_type = init_run_data[
                    'train_dict_embeddings'], init_run_data[
                        'dict_idxs_by_type']
                train_dict_indexes = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, men_idxs_by_type = init_run_data[
                    'train_men_embeddings'], init_run_data['men_idxs_by_type']
                train_men_indexes = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = dict_embed_data[
                        'dict_embeds'], dict_embed_data[
                            'dict_indexes'], dict_embed_data[
                                'dict_idxs_by_type']
                else:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        corpus=entity_dictionary,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=train_processed_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if load_stored_data:
                train_dict_embeddings = init_run_data['train_dict_embeddings']
                train_dict_index = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings = init_run_data['train_men_embeddings']
                train_men_index = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_index = dict_embed_data[
                        'dict_embeds'], dict_embed_data['dict_index']
                else:
                    train_dict_embeddings, train_dict_index = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_index = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the initial embeddings and index if this is the first run and data isn't persistent
        if init_base_model_run and not load_stored_data:
            init_run_data = {}
            init_run_data['train_dict_embeddings'] = train_dict_embeddings
            init_run_data['train_men_embeddings'] = train_men_embeddings
            if use_types:
                init_run_data['dict_idxs_by_type'] = dict_idxs_by_type
                init_run_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(init_run_data,
                       init_run_pkl_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        init_base_model_run = False

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        logger.info("Starting KNN search...")
        if not use_types:
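            # faiss Index.search returns (distances, indices); keep only the
            # indices of the knn nearest entities for every mention embedding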
            _, dict_nns = train_dict_index.search(train_men_embeddings, knn)
        else:
            dict_nns = np.zeros((len(train_men_embeddings), knn))
            for entity_type in train_men_indexes:
                men_embeds_by_type = train_men_embeddings[
                    men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = train_dict_indexes[entity_type].search(
                    men_embeds_by_type, knn)
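                # Positions returned by the per-type index are local to that
                # index; map them back to global dictionary indices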
                dict_nns_idxs = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Search finished")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_inputs, candidate_idxs, n_gold, mention_idxs = batch
            mention_embeddings = train_men_embeddings[mention_idxs.cpu()]

            if len(mention_embeddings.shape) == 1:
                mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

            # context_inputs: Shape: batch x token_len
            # Flat entity indices, later mapped to token ids of shape
            # (batch*knn) x token_len
            candidate_inputs = np.array([], dtype=np.int64)
            label_inputs = torch.tensor(
                [[1] + [0] * (knn - 1)] * n_gold.sum(),
                dtype=torch.float32)  # Shape: batch(with split rows) x knn
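            # One training row per gold label: slot 0 holds the gold candidate
            # (label 1), the remaining knn - 1 slots hold negatives (label 0)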
            context_inputs_split = torch.zeros(
                (label_inputs.size(0), context_inputs.size(1)),
                dtype=torch.long)  # Shape: batch(with split rows) x token_len
            # label_inputs = (candidate_idxs >= 0).type(torch.float32) # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                knn_dict_idxs = dict_nns[mention_idxs[i]]
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                for ng, gold_idx in enumerate(gold_idxs):
                    context_inputs_split[i + ng] = context_inputs[i]
                    candidate_inputs = np.concatenate(
                        (candidate_inputs,
                         np.concatenate(
                             ([gold_idx],
                              knn_dict_idxs[~np.isin(knn_dict_idxs, gold_idxs)]
                              ))[:knn]))
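            # candidate_inputs now holds, per row, the gold entity index in
            # slot 0 followed by its (knn - 1) nearest non-gold entity indices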
            # Map entity indices to their token-id vectors
            candidate_inputs = entity_dict_vecs[
                torch.from_numpy(candidate_inputs)].cuda()
            context_inputs_split = context_inputs_split.cuda()
            label_inputs = label_inputs.cuda()

            loss, _ = reranker(context_inputs_split,
                               candidate_inputs,
                               label_inputs,
                               pos_neg_loss=params["pos_neg_loss"])

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(reranker,
                         entity_dict_vecs,
                         valid_men_vecs,
                         device=device,
                         logger=logger,
                         knn=knn,
                         n_gpu=n_gpu,
                         entity_data=entity_dictionary,
                         query_data=valid_processed_data,
                         silent=params["silent"],
                         use_types=use_types or params["use_types_for_eval"],
                         embed_batch_size=params['embed_batch_size'],
                         force_exact_search=use_types
                         or params["use_types_for_eval"]
                         or params["force_exact_search"],
                         probe_mult_factor=params['probe_mult_factor'])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine-tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        logger.info(f"Model saved at {epoch_output_folder_path}")

        normalized_accuracy, dict_embed_data = evaluate(
            reranker,
            entity_dict_vecs,
            valid_men_vecs,
            device=device,
            logger=logger,
            knn=knn,
            n_gpu=n_gpu,
            entity_data=entity_dictionary,
            query_data=valid_processed_data,
            silent=params["silent"],
            use_types=use_types or params["use_types_for_eval"],
            embed_batch_size=params['embed_batch_size'],
            force_exact_search=use_types or params["use_types_for_eval"]
            or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'])

        if normalized_accuracy > best_score:
            best_score = normalized_accuracy
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)
    logger.info(f"Best model saved at {model_output_path}")
Example No. 11
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model 
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer

    # load entities
    entity_dict, entity_json = load_entity_dict(logger, params)

    # load tfidf candidates
    tfidf_cand_dict = read_tfidf_cands(params["data_path"], params["mode"])

    # load mentions
    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples." % len(test_samples))

    # get only the cands we need to tokenize
    cand_ids = [c for l in tfidf_cand_dict.values() for c in l]
    cand_ids.extend([x["label_umls_cuid"] for x in test_samples])
    cand_ids = list(set(cand_ids))
    num_cands = len(cand_ids)

    # tokenize the candidates
    cand_uid_map = {c : i for i, c in enumerate(cand_ids)}
    candidate_pool = get_candidate_pool_tensor(
        [entity_dict[c] for c in cand_ids],
        tokenizer,
        params["max_cand_length"],
        logger
    )

    # create mention maps
    ctxt_uid_map = {x["mm_mention_id"] : i + num_cands
                        for i, x in enumerate(test_samples)}
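    # UID scheme: candidates occupy uids [0, num_cands); mention contexts
    # follow at num_cands + i (the asserts in the loops below rely on this layout)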
    ctxt_cand_map = {x["mm_mention_id"] : x["label_umls_cuid"]
                        for x in test_samples}
    ctxt_doc_map = {x["mm_mention_id"] : x["context_doc_id"]
                        for x in test_samples}
    doc_ctxt_map = defaultdict(list)
    for c, d in ctxt_doc_map.items():
        doc_ctxt_map[d].append(c)

    # create text maps for investigative evaluation
    uid_to_json = {
        uid : entity_json[cuid] for cuid, uid in cand_uid_map.items()
    }
    uid_to_json.update({i+num_cands : x for i, x in enumerate(test_samples)})

    # tokenize the contexts
    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    context_pool = test_data["context_vecs"]
    
    # create output variables
    contexts = context_pool
    context_uids = torch.LongTensor(list(ctxt_uid_map.values()))

    pos_coref_ctxts = []
    pos_coref_ctxt_uids = []
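    # Positive coreferent contexts: other mentions in the same document that
    # are labeled with the same entity as the query mention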
    for i, c in enumerate(ctxt_doc_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        doc = ctxt_doc_map[c]
        coref_ctxts = [x for x in doc_ctxt_map[doc]
                          if x != c and ctxt_cand_map[x] == ctxt_cand_map[c]]
        coref_ctxt_uids = [ctxt_uid_map[x] for x in coref_ctxts]
        coref_ctxt_idxs = [x - num_cands for x in coref_ctxt_uids]
        pos_coref_ctxts.append(context_pool[coref_ctxt_idxs])
        pos_coref_ctxt_uids.append(torch.LongTensor(coref_ctxt_uids))

    knn_ctxts = []
    knn_ctxt_uids = []
    for i, c in enumerate(ctxt_doc_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        doc = ctxt_doc_map[c]
        wdoc_ctxts = [x for x in doc_ctxt_map[doc] if x != c]
        wdoc_ctxt_uids = [ctxt_uid_map[x] for x in wdoc_ctxts]
        wdoc_ctxt_idxs = [x - num_cands for x in wdoc_ctxt_uids]
        knn_ctxts.append(context_pool[wdoc_ctxt_idxs])
        knn_ctxt_uids.append(torch.LongTensor(wdoc_ctxt_uids))
        
    pos_cands = []
    pos_cand_uids = []
    for i, c in enumerate(ctxt_cand_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        pos_cands.append(candidate_pool[cand_uid_map[ctxt_cand_map[c]]])
        pos_cand_uids.append(torch.LongTensor([cand_uid_map[ctxt_cand_map[c]]]))

    knn_cands = []
    knn_cand_uids = []
    for i, c in enumerate(ctxt_cand_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        tfidf_cands = tfidf_cand_dict.get(c, [])
        tfidf_cand_uids = [cand_uid_map[x] for x in tfidf_cands]
        knn_cands.append(candidate_pool[tfidf_cand_uids])
        knn_cand_uids.append(torch.LongTensor(tfidf_cand_uids))

    tfidf_data = {
        "contexts": contexts,
        "context_uids": context_uids,
        "pos_coref_ctxts": pos_coref_ctxts,
        "pos_coref_ctxt_uids": pos_coref_ctxt_uids,
        "knn_ctxts": knn_ctxts,
        "knn_ctxt_uids": knn_ctxt_uids,
        "pos_cands": pos_cands,
        "pos_cand_uids": pos_cand_uids,
        "knn_cands": knn_cands,
        "knn_cand_uids": knn_cand_uids,
        "uid_to_json": uid_to_json,
    }
    
    save_data_path = os.path.join(
        params['output_path'], 
        'joint_candidates_%s_tfidf.t7' % (params['mode'])
    )
    torch.save(tfidf_data, save_data_path)
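    # The saved bundle can be restored downstream with, e.g.,
    # tfidf_data = torch.load(save_data_path)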
Example No. 12
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }
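        # Indexing convention: entities occupy rows/cols [0, n_entities);
        # mention i is stored at offset n_entities + i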

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                if 'dict_idxs_by_type' in embed_data:
                    dict_idxs_by_type = embed_data['dict_idxs_by_type']
                else:
                    dict_idxs_by_type = data_process.get_idxs_by_type(
                        test_dictionary)
                dict_indexes = data_process.get_index_from_embeds(
                    dict_embeds,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                if 'men_idxs_by_type' in embed_data:
                    men_idxs_by_type = embed_data['men_idxs_by_type']
                else:
                    men_idxs_by_type = data_process.get_idxs_by_type(
                        mention_data)
                men_indexes = data_process.get_index_from_embeds(
                    men_embeds,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    encoder_type="candidate",
                    n_gpu=n_gpu,
                    corpus=test_dictionary,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                men_data = mention_data
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                    men_data = train_mention_data + mention_data
                men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=men_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                dict_index = data_process.get_index_from_embeds(
                    dict_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                men_index = data_process.get_index_from_embeds(
                    men_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_index = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    'candidate',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_embeds, men_index = data_process.embed_and_index(
                    reranker,
                    vecs,
                    'context',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['dict_embeds'] = dict_embeds
            embed_data['men_embeds'] = men_embeds
            if use_types:
                embed_data['dict_idxs_by_type'] = dict_idxs_by_type
                embed_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        recall_accuracy = {
            2**i: 0
            for i in range(int(math.log(params['recall_k'], 2)) + 1)
        }
        recall_idxs = [0.] * params['recall_k']

        logger.info("Starting KNN search...")
        # Fetch recall_k (default 16) knn entities for all mentions
        # Fetch (k+1) NN mention candidates
        if not use_types:
            _men_embeds = men_embeds
            if params['transductive']:
                _men_embeds = _men_embeds[n_train_mentions:]
            nn_ent_dists, nn_ent_idxs = dict_index.search(
                _men_embeds, params['recall_k'])
            n_mens_to_fetch = len(_men_embeds) if within_doc else max_knn + 1
            nn_men_dists, nn_men_idxs = men_index.search(
                _men_embeds, n_mens_to_fetch)
        else:
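            # Per-type search: query each entity type's own faiss index and
            # scatter the results back into global arrays (-1 marks padding)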
            query_len = len(men_embeds) - (
                n_train_mentions if params['transductive'] else 0)
            # int dtype so downstream sparse-graph indices stay integral
            nn_ent_idxs = np.zeros((query_len, params['recall_k']), dtype=int)
            nn_ent_dists = np.zeros((query_len, params['recall_k']),
                                    dtype='float64')
            nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
            nn_men_dists = -1 * np.ones((query_len, query_len),
                                        dtype='float64')
            for entity_type in men_indexes:
                # Select this type's mention embeddings, dropping training
                # mentions in the transductive setting
                if params['transductive']:
                    type_idxs = men_idxs_by_type[entity_type][
                        men_idxs_by_type[entity_type] >= n_train_mentions]
                else:
                    type_idxs = men_idxs_by_type[entity_type]
                men_embeds_by_type = men_embeds[type_idxs]
                nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                    entity_type].search(men_embeds_by_type, params['recall_k'])
                # Map type-local ranks back to global dictionary indices
                nn_ent_idxs_by_type = dict_idxs_by_type[entity_type][
                    nn_ent_idxs_by_type]
                n_mens_to_fetch = len(
                    men_embeds_by_type) if within_doc else max_knn + 1
                nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                    entity_type].search(
                        men_embeds_by_type,
                        min(n_mens_to_fetch, len(men_embeds_by_type)))
                # Map type-local ranks back to global mention indices
                nn_men_idxs_by_type = men_idxs_by_type[entity_type][
                    nn_men_idxs_by_type]
                # i tracks rows of the per-type search results, which cover
                # only query (non-training) mentions
                i = -1
                for idx in men_idxs_by_type[entity_type]:
                    if params['transductive']:
                        idx -= n_train_mentions
                    if idx < 0:
                        # Training mention: no search results to scatter
                        continue
                    i += 1
                    nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                    nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                    nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = \
                        nn_men_idxs_by_type[i]
                    nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = \
                        nn_men_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar entity and k-nn mentions for each mention query
        for idx in range(len(nn_ent_idxs)):
            # Get nearest entity candidate
            dict_cand_idx = nn_ent_idxs[idx][0]
            dict_cand_score = nn_ent_dists[idx][0]
            # Compute recall metric
            gold_idxs = mention_data[idx][
                "label_idxs"][:mention_data[idx]["n_labels"]]
            recall_idx = np.argwhere(nn_ent_idxs[idx] == gold_idxs[0])
            if len(recall_idx) != 0:
                recall_idx = int(recall_idx)
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.
            if not params['only_recall']:
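                # -1 entries are padding (no search result); drop them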
                filter_mask_neg1 = nn_men_idxs[idx] != -1
                men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
                men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

                if within_doc:
                    men_cand_idxs, wd_mask = filter_by_context_doc_id(
                        men_cand_idxs,
                        test_context_doc_ids[idx],
                        test_context_doc_ids,
                        return_numpy=True)
                    men_cand_scores = men_cand_scores[wd_mask]

                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != idx
                men_cand_idxs = men_cand_idxs[filter_mask][:max_knn]
                men_cand_scores = men_cand_scores[filter_mask][:max_knn]

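                # In the transductive graph, training mentions occupy the
                # first n_train_mentions mention slots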
                if params['transductive']:
                    idx += n_train_mentions
                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge; mention nodes are offset by
                    # n_entities so they do not collide with entity node ids
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    dict_cand_score)
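                    # The k = 0 graph keeps only mention-entity edges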
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'],
                            n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        if params['transductive']:
            # Add infinite-weight mention-entity edges from training queries
            # to their labeled (gold) entities
            for idx, train_men in enumerate(train_mention_data):
                dict_cand_idx = train_men["label_idxs"][0]
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Mention nodes are offset by n_entities
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    float('inf'))

        # Compute and print recall metric
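        # recall_idx_mode = the rank at which gold entities are found most often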
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(
            recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(nn_ent_idxs)} queries):
        ---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})"
        )
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(nn_ent_idxs)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    graph_mode = params.get('graph_mode', None)

    result_overview = {
        'n_entities': n_entities,
        'n_mentions':
        n_mentions - (n_train_mentions if params['transductive'] else 0)
    }
    results = {}
    if graph_mode not in ['directed', 'undirected']:
        # No valid mode specified: evaluate both
        results['directed'] = []
        results['undirected'] = []
    else:
        results[graph_mode] = []

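    # Time spent so far on embedding, indexing and k-NN retrieval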
    knn_fetch_time = time.time() - time_start
    graph_processing_time = time.time()
    n_graphs_processed = 0.

    for mode in results:
        print(f'\nEvaluation mode: {mode.upper()}')
        for k in joint_graphs:
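            # Only evaluate graphs up to the requested neighborhood size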
            if k <= knn:
                print(f"\nGraph (k={k}):")
                # Partition graph based on cluster-linking constraints
                partitioned_graph, clusters = partition_graph(
                    joint_graphs[k],
                    n_entities,
                    mode == 'directed',
                    return_clusters=True)
                # Infer predictions from clusters
                result = analyzeClusters(
                    clusters, test_dictionary, mention_data, k,
                    n_train_mentions if params['transductive'] else 0)
                # Store result
                results[mode].append(result)
                n_graphs_processed += 1

    # Guard against division by zero when no graph had k <= knn
    avg_graph_processing_time = (
        time.time() - graph_processing_time) / max(n_graphs_processed, 1)
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60

    execution_time = (time.time() - time_start) / 60
    # Store results
    # Unix timestamp used to uniquely name the result files
    output_file_name = os.path.join(output_path,
                                    f"eval_results_{int(time.time())}")

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        # recall_accuracy is only defined when graphs were built in this run
        logger.info(
            "Recall data not available since graphs were loaded from disk")

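    # Write per-mode, per-k predictions and an overall overview to JSON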
    for mode in results:
        mode_results = results[mode]
        result_overview[mode] = {}
        for r in mode_results:
            k = r['knn_mentions']
            result_overview[mode][f'accuracy@knn{k}'] = r['accuracy']
            logger.info(f"{mode} accuracy@knn{k} = {r['accuracy']}")
            output_file = f'{output_file_name}-{mode}-{k}.json'
            with open(output_file, 'w') as f:
                json.dump(r, f, indent=2)
                print(
                    f"\nPredictions ({mode}) @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")

    logger.info(
        f"\nAverage per-graph evaluation time: {avg_per_graph_time} minutes\n")
    logger.info(f"\nTotal evaluation time: {execution_time} minutes\n")