import sys

import numpy as np
import torch
from tqdm import tqdm

# NOTE: assumed import path — in the BLINK repo, get_candidate_representation
# is provided by blink.biencoder.data_process.
import blink.biencoder.data_process as data


def get_candidate_pool_tensor(
    entity_desc_list,
    tokenizer,
    max_seq_length,
    logger,
):
    """Tokenize every candidate entity description into a fixed-length id tensor."""
    # TODO: parallelize this loop (multiprocessing / multithreading)
    logger.info("Convert candidate text to id")
    cand_pool = []
    for entity_desc in tqdm(entity_desc_list):
        # Entries are either (title, description) tuples or bare description strings.
        if isinstance(entity_desc, tuple):
            title, entity_text = entity_desc
        else:
            title = None
            entity_text = entity_desc

        rep = data.get_candidate_representation(
            entity_text,
            tokenizer,
            max_seq_length,
            title,
        )
        cand_pool.append(rep["ids"])

    # Shape: (num_candidates, max_seq_length)
    cand_pool = torch.LongTensor(cand_pool)
    return cand_pool
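# Minimal usage sketch (illustrative; assumes `tokenizer` is a HuggingFace
# tokenizer, and the entity descriptions below are hypothetical):
#
#     import logging
#     pool = get_candidate_pool_tensor(
#         [("Paris", "Paris is the capital of France."), "An untitled description."],
#         tokenizer,
#         max_seq_length=128,
#         logger=logging.getLogger(__name__),
#     )
#     # pool.shape == (2, 128)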
def prepare_crossencoder_candidates(
    tokenizer, labels, nns, id2title, id2text, max_cand_length=128, topk=100
):
    """Tokenize the top-k retrieved candidates per sample and locate the gold label.

    Returns, per sample, the index of the gold entity within its candidate list
    (-1 if the retriever missed it) and the tokenized candidate inputs.
    """
    # Special tokens (currently unused in this function).
    START_TOKEN = tokenizer.cls_token
    END_TOKEN = tokenizer.sep_token

    candidate_input_list = []  # samples x topk x max_cand_length
    label_input_list = []  # samples
    idx = 0
    for label, nn in zip(labels, nns):
        candidates = []

        label_id = -1
        for jdx, candidate_id in enumerate(nn[:topk]):
            # Remember where the gold entity sits among the retrieved candidates.
            if label == candidate_id:
                label_id = jdx
            rep = data.get_candidate_representation(
                id2text[candidate_id],
                tokenizer,
                max_cand_length,
                id2title[candidate_id],
            )
            tokens_ids = rep["ids"]
            assert len(tokens_ids) == max_cand_length
            candidates.append(tokens_ids)

        label_input_list.append(label_id)
        candidate_input_list.append(candidates)

        idx += 1
        sys.stdout.write("{}/{} \r".format(idx, len(labels)))
        sys.stdout.flush()

    label_input_list = np.asarray(label_input_list)
    candidate_input_list = np.asarray(candidate_input_list)

    return label_input_list, candidate_input_list
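# Minimal usage sketch (illustrative; the ids and mappings below are
# hypothetical — in practice `nns` holds the nearest-neighbour candidate ids
# produced by the biencoder retriever):
#
#     label_ids, cand_inputs = prepare_crossencoder_candidates(
#         tokenizer,
#         labels=[42],                        # gold entity id per sample
#         nns=[[7, 42, 13]],                  # retrieved candidate ids per sample
#         id2title={7: "A", 42: "B", 13: "C"},
#         id2text={7: "text a", 42: "text b", 13: "text c"},
#         max_cand_length=128,
#         topk=100,
#     )
#     # label_ids == array([1]): the gold entity is the second retrieved candidate
#     # cand_inputs.shape == (1, 3, 128)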