Example #1
import itertools
import operator

from transformers import LongformerTokenizer

# SPECIAL_TITLE_START, SPECIAL_TITLE_END, SPECIAL_SENTENCE_TOKEN, CLS_TOKEN,
# SPECIAL_QUERY_START and SPECIAL_QUERY_END are marker strings defined
# elsewhere in the project.


def document_encoder(title: str, doc_sents: list,
                     tokenizer: LongformerTokenizer):
    # Wrap the title in its boundary markers, then tokenize and map to ids
    # without letting the tokenizer insert its own special tokens.
    title_res = SPECIAL_TITLE_START + title + SPECIAL_TITLE_END
    title_tokens = tokenizer.tokenize(text=title_res)
    title_encode_ids = tokenizer.encode(text=title_tokens,
                                        add_special_tokens=False)
    assert len(title_tokens) == len(title_encode_ids)
    title_len = len(title_encode_ids)
    # Track each segment's ids and length; the title segment comes first.
    encode_id_lens = [title_len]
    doc_encode_id_list = [title_encode_ids]
    for sent_text in doc_sents:
        # Append the sentence marker token, then tokenize and encode the sentence.
        sent_text_res = sent_text + SPECIAL_SENTENCE_TOKEN
        sent_tokens = tokenizer.tokenize(text=sent_text_res)
        sent_encode_ids = tokenizer.encode(text=sent_tokens,
                                           add_special_tokens=False)
        assert len(sent_tokens) == len(sent_encode_ids)
        doc_encode_id_list.append(sent_encode_ids)
        encode_id_lens.append(len(sent_encode_ids))
    # Cumulative segment lengths give the (exclusive) end offset of the title
    # and of every sentence within the flattened document sequence.
    doc_sent_len_cum_list = list(
        itertools.accumulate(encode_id_lens, operator.add))
    # Inclusive (start, end) token positions for each sentence; the title
    # occupies positions [0, title_len - 1] and is not included here.
    sent_start_end_pair = [(doc_sent_len_cum_list[i],
                            doc_sent_len_cum_list[i + 1] - 1)
                           for i in range(len(encode_id_lens) - 1)]
    # Flatten the per-segment id lists into a single id sequence.
    doc_encode_ids = list(itertools.chain.from_iterable(doc_encode_id_list))
    assert len(doc_encode_ids) == doc_sent_len_cum_list[-1]
    assert len(sent_start_end_pair) == len(doc_sents)
    return doc_encode_ids, sent_start_end_pair, len(doc_encode_ids), title_len
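
A minimal usage sketch: the marker values below ('<t>', '</t>', '<s_sent>') and
the allenai/longformer-base-4096 checkpoint are illustrative assumptions, not
values taken from the project.

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
SPECIAL_TITLE_START, SPECIAL_TITLE_END = '<t>', '</t>'   # assumed marker values
SPECIAL_SENTENCE_TOKEN = '<s_sent>'                      # assumed marker value
# Register the markers so each one is tokenized as a single token.
tokenizer.add_tokens([SPECIAL_TITLE_START, SPECIAL_TITLE_END, SPECIAL_SENTENCE_TOKEN])

doc_ids, sent_spans, doc_len, title_len = document_encoder(
    title='Barack Obama',
    doc_sents=['Barack Obama was the 44th U.S. president.',
               'He served two terms.'],
    tokenizer=tokenizer)
# sent_spans[i] holds the inclusive (start, end) positions of sentence i
# within doc_ids, offset past the encoded title; doc_len == len(doc_ids).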
Example #2
def query_encoder(query: str, tokenizer: LongformerTokenizer):
    # Prepend the CLS token and wrap the query in its boundary markers,
    # then tokenize and encode without the tokenizer's own special tokens.
    query_res = CLS_TOKEN + SPECIAL_QUERY_START + query + SPECIAL_QUERY_END
    query_tokens = tokenizer.tokenize(text=query_res)
    query_encode_ids = tokenizer.encode(text=query_tokens,
                                        add_special_tokens=False)
    assert len(query_tokens) == len(query_encode_ids)
    query_len = len(query_encode_ids)
    return query_encode_ids, query_len
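
A matching sketch for the query side, continuing from the example above.
CLS_TOKEN is taken from the tokenizer itself; the query markers are again
assumed values rather than the project's own constants.

CLS_TOKEN = tokenizer.cls_token                          # '<s>' for Longformer
SPECIAL_QUERY_START, SPECIAL_QUERY_END = '<q>', '</q>'   # assumed marker values
tokenizer.add_tokens([SPECIAL_QUERY_START, SPECIAL_QUERY_END])

query_ids, query_len = query_encoder('Who was the 44th U.S. president?', tokenizer)
# query_ids starts with the CLS id followed by the marked-up query tokens;
# in this kind of pipeline the document ids from document_encoder are
# typically concatenated after the query ids to form the model input.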