def test_full_tokenizer(self):
        """Write a tiny WordPiece vocab to a temporary file and check BertTokenizer tokenization and token-to-id conversion."""
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname,
                                      VOCAB_FILES_NAMES['vocab_file'])
            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            input_text = u"UNwant\u00E9d,running"
            output_text = u"unwanted, running"

            create_and_check_tokenizer_commons(self, input_text, output_text,
                                               BertTokenizer, tmpdirname)

            tokenizer = BertTokenizer(vocab_file)

            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
            self.assertListEqual(
                tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                                 [7, 4, 5, 10, 8, 9])
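
The expected ids in the assertion above are simply each token's 0-based position in vocab_tokens (i.e. its line number in the vocab file). A minimal standalone sketch, independent of the tokenizer itself, that reproduces the mapping:

vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
    "wa", "un", "runn", "##ing", ",", "low", "lowest",
]
# WordPiece ids correspond to each token's 0-based position in the vocab file.
token_to_id = {token: idx for idx, token in enumerate(vocab_tokens)}
tokens = ["un", "##want", "##ed", ",", "runn", "##ing"]
assert [token_to_id[t] for t in tokens] == [7, 4, 5, 10, 8, 9]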
Example #2
import math

import torch
from transformers import BertTokenizer  # assumed import; use the BertTokenizer class from your BERT library


def encode_documents(documents: list,
                     tokenizer: BertTokenizer,
                     max_input_length=512):
    """
    Returns a len(documents) * max_sequences_per_document * 3 * 512 tensor where len(documents) is the batch
    dimension and the others encode bert input.

    This is the input to any of the document bert architectures.

    :param documents: a list of text documents
    :param tokenizer: the sentence piece bert tokenizer
    :return:
    """
    # Only take the first 10200 tokens of each document as input (added by AD).
    tokenized_documents = [
        tokenizer.tokenize(document)[:10200] for document in documents
    ]
    # Each sequence reserves two positions for the [CLS] and [SEP] tokens.
    max_sequences_per_document = math.ceil(
        max(len(x) / (max_input_length - 2) for x in tokenized_documents))
    assert max_sequences_per_document <= 20, "Document is too large (arbitrary limit chosen when this was written)"

    output = torch.zeros(size=(len(documents), max_sequences_per_document, 3,
                               max_input_length),
                         dtype=torch.long)

    # DistilBERT cannot handle completely empty sequences, so pre-fill every slot with the
    # minimal valid sequence '[CLS]', '[SEP]' followed by padding, and a matching attention mask.
    for doc_id in range(len(documents)):
        for seq_id in range(max_sequences_per_document):
            output[doc_id, seq_id, 0] = torch.LongTensor(
                tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]']) + [0] *
                (max_input_length - 2))  # input_ids
            output[doc_id, seq_id, 2] = torch.LongTensor(
                [1] * 2 + [0] * (max_input_length - 2))  # attention_mask

    document_seq_lengths = []  # number of sequences generated per document
    # Chunk in steps of max_input_length - 2 to leave room for the [CLS] and [SEP] tokens.
    for doc_index, tokenized_document in enumerate(tokenized_documents):
        max_seq_index = 0
        for seq_index, i in enumerate(
                range(0, len(tokenized_document), (max_input_length - 2))):
            raw_tokens = tokenized_document[i:i + (max_input_length - 2)]
            tokens = []
            input_type_ids = []

            tokens.append("[CLS]")
            input_type_ids.append(0)
            for token in raw_tokens:
                tokens.append(token)
                input_type_ids.append(0)
            tokens.append("[SEP]")
            input_type_ids.append(0)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            attention_masks = [1] * len(input_ids)

            while len(input_ids) < max_input_length:
                input_ids.append(0)
                input_type_ids.append(0)
                attention_masks.append(0)

            assert (len(input_ids) == max_input_length
                    and len(attention_masks) == max_input_length
                    and len(input_type_ids) == max_input_length)

            # Stack input_ids, token_type_ids and attention_mask for this chunk.
            output[doc_index][seq_index] = torch.cat(
                (torch.LongTensor(input_ids).unsqueeze(0),
                 torch.LongTensor(input_type_ids).unsqueeze(0),
                 torch.LongTensor(attention_masks).unsqueeze(0)),
                dim=0)
            max_seq_index = seq_index
        document_seq_lengths.append(max_seq_index + 1)
    return output, torch.LongTensor(document_seq_lengths)
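
A minimal usage sketch, assuming the Hugging Face transformers BertTokenizer and made-up document strings:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
documents = [
    "A short document about encoding BERT inputs.",
    "A much longer second document that will be split into several chunks. " * 50,
]
batch, seq_lengths = encode_documents(documents, tokenizer)
# batch has shape (2, max_sequences_per_document, 3, 512):
# channel 0 = input_ids, channel 1 = token_type_ids, channel 2 = attention_mask.
print(batch.shape, seq_lengths)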