def _encode_segment(text: str, tokenizer: LongformerTokenizer) -> list:
    """Tokenize *text* and convert the tokens to ids, without adding special tokens."""
    tokens = tokenizer.tokenize(text=text)
    encode_ids = tokenizer.encode(text=tokens, add_special_tokens=False)
    # Sanity check: the tokenizer must yield exactly one id per token.
    assert len(tokens) == len(encode_ids)
    return encode_ids


def document_encoder(title: str, doc_sents: list, tokenizer: LongformerTokenizer):
    """Encode a titled document into one flat id sequence plus per-sentence spans.

    Before tokenization, the title is wrapped as
    ``SPECIAL_TITLE_START + title + SPECIAL_TITLE_END``, and every sentence
    gets ``SPECIAL_SENTENCE_TOKEN`` appended. Each segment is encoded with
    ``add_special_tokens=False``. The segments are then concatenated in order,
    title first.

    Args:
        title: Document title text.
        doc_sents: Sentence strings, in document order.
        tokenizer: Tokenizer exposing ``tokenize`` and ``encode``.

    Returns:
        A 4-tuple ``(doc_encode_ids, sent_start_end_pair, doc_len, title_len)``:

        - ``doc_encode_ids``: the concatenated ids of the title and all sentences.
        - ``sent_start_end_pair``: one ``(start, end)`` pair per sentence, giving
          *inclusive* positions of that sentence within ``doc_encode_ids``.
        - ``doc_len``: ``len(doc_encode_ids)``.
        - ``title_len``: the number of ids taken by the title segment.
    """
    title_encode_ids = _encode_segment(
        SPECIAL_TITLE_START + title + SPECIAL_TITLE_END, tokenizer
    )
    title_len = len(title_encode_ids)

    # Segment 0 is the title; segment i + 1 is sentence i.
    doc_encode_id_list = [title_encode_ids]
    doc_encode_id_list.extend(
        _encode_segment(sent_text + SPECIAL_SENTENCE_TOKEN, tokenizer)
        for sent_text in doc_sents
    )
    encode_id_lens = [len(segment_ids) for segment_ids in doc_encode_id_list]

    # doc_sent_len_cum_list[k] is the offset just past segment k. Sentence i
    # therefore spans [cum[i], cum[i + 1] - 1] (inclusive) in the flat sequence.
    doc_sent_len_cum_list = list(itertools.accumulate(encode_id_lens))
    sent_start_end_pair = [
        (start, next_start - 1)
        for start, next_start in zip(doc_sent_len_cum_list, doc_sent_len_cum_list[1:])
    ]

    doc_encode_ids = list(itertools.chain.from_iterable(doc_encode_id_list))
    assert len(doc_encode_ids) == doc_sent_len_cum_list[-1]
    assert len(sent_start_end_pair) == len(doc_sents)
    return doc_encode_ids, sent_start_end_pair, len(doc_encode_ids), title_len
def query_encoder(query: str, tokenizer: LongformerTokenizer):
    """Encode a query as ``CLS_TOKEN`` followed by the marker-wrapped query text.

    The query is surrounded by ``SPECIAL_QUERY_START`` / ``SPECIAL_QUERY_END``
    and prefixed with ``CLS_TOKEN``. It is then tokenized and converted to ids
    with ``add_special_tokens=False``.

    Args:
        query: Raw query text.
        tokenizer: Tokenizer exposing ``tokenize`` and ``encode``.

    Returns:
        ``(ids, n_ids)``: the encoded id list and its length.
    """
    marked_query = "".join((CLS_TOKEN, SPECIAL_QUERY_START, query, SPECIAL_QUERY_END))
    pieces = tokenizer.tokenize(text=marked_query)
    ids = tokenizer.encode(text=pieces, add_special_tokens=False)
    # Sanity check: exactly one id per token piece.
    assert len(pieces) == len(ids)
    return ids, len(ids)