def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname,
                                      VOCAB_FILES_NAMES['vocab_file'])
            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            input_text = u"UNwant\u00E9d,running"
            output_text = u"unwanted, running"

            create_and_check_tokenizer_commons(self, input_text, output_text,
                                               BertTokenizer, tmpdirname)

            tokenizer = BertTokenizer(vocab_file)

            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
            self.assertListEqual(
                tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                                 [7, 4, 5, 10, 8, 9])
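
The expected ids in the final assertion follow directly from the order of vocab_tokens above; a minimal sketch (reusing the vocab_tokens list from this test, not part of the original) that reproduces them:

    # Each token's id is simply its line index in the vocab file written above.
    expected_ids = [vocab_tokens.index(t)
                    for t in ["un", "##want", "##ed", ",", "runn", "##ing"]]
    assert expected_ids == [7, 4, 5, 10, 8, 9]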
Example 2
    def __init__(self,
                 path,
                 tokenizer: BertTokenizer,
                 max_seq_length,
                 readin: int = 2000000,
                 dupe_factor: int = 5,
                 small_seq_prob: float = 0.1):
        self.dupe_factor = dupe_factor
        self.max_seq_length = max_seq_length
        self.small_seq_prob = small_seq_prob

        documents = []
        instances = []
        with open(path, encoding='utf-8') as fd:
            for i, line in enumerate(tqdm(fd)):
                line = line.replace('\n', '')
                # Expected format: (Q, T, U, S, D)
                # query, title, url, snippet, document = line.split('\t')
                # NOTE: for now the whole line is treated as the document;
                # remove the following line once the tab-separated format is used.
                document = line
                segments = document.split("<sep>")
                if len(segments) <= 3:  # skip documents with too few segments
                    continue
                # Tokenize each "<sep>"-delimited segment of the document
                document = [tokenizer.tokenize(seq) for seq in segments]
                documents.append(document)

        documents = [x for x in documents if x]

        self.documents = documents
        for _ in range(self.dupe_factor):
            for index in range(len(self.documents)):
                instances.extend(self.create_training_instance(index))

        shuffle(instances)
        self.instances = instances
        self.len = len(self.instances)
        self.documents = None
        documents = None
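
A hedged usage sketch for the constructor above; the class name PretrainingDataset, the corpus path, and the pretrained model name are assumptions, since the snippet only shows __init__ (it also relies on tqdm, random.shuffle and a create_training_instance method defined elsewhere in the class):

    from transformers import BertTokenizer  # older code may import from pytorch_pretrained_bert

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = PretrainingDataset(             # hypothetical name for the class defining __init__ above
        path="corpus.txt",                    # one document per line, segments joined by "<sep>"
        tokenizer=tokenizer,
        max_seq_length=128,
        dupe_factor=5,
        small_seq_prob=0.1,
    )
    print(dataset.len, "training instances")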
Example 3
    def __init__(self,
                 path,
                 tokenizer: BertTokenizer,
                 max_seq_length: int = 512,
                 readin: int = 2000000,
                 dupe_factor: int = 6,
                 small_seq_prob: float = 0.1):
        self.dupe_factor = dupe_factor
        self.max_seq_length = max_seq_length
        self.small_seq_prob = small_seq_prob

        documents = []
        instances = []
        with open(path, encoding='utf-8') as fd:
            document = []
            for i, line in enumerate(tqdm(fd)):
                line = line.replace('\n', '')
                if len(line) == 0:  # a blank line marks the end of a document
                    documents.append(document)
                    document = []
                    continue
                if len(line.split(' ')) > 2:  # keep only lines with more than two words
                    document.append(tokenizer.tokenize(line))
            if len(document) > 0:
                documents.append(document)

        documents = [x for x in documents if x]
        print(documents[0])  # debug: inspect the first parsed document
        print(len(documents))  # debug: total number of documents
        self.documents = documents
        for _ in range(self.dupe_factor):
            for index in range(len(self.documents)):
                instances.extend(self.create_training_instance(index))

        shuffle(instances)
        self.instances = instances
        self.len = len(self.instances)
        self.documents = None
        documents = None
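
Unlike the previous example, this constructor expects one sentence per line, with a blank line marking the end of each document, and it keeps only lines with more than two words; a minimal sketch (file name assumed) of the input format it consumes:

    sample = (
        "the first sentence of document one has enough words\n"
        "the second sentence of document one follows here\n"
        "\n"
        "document two starts here with another long sentence\n"
    )
    with open("pretraining_corpus.txt", "w", encoding="utf-8") as fd:
        fd.write(sample)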
Example 4
import math

import torch

from transformers import BertTokenizer  # older code may import this from pytorch_pretrained_bert


def encode_documents(documents: list,
                     tokenizer: BertTokenizer,
                     max_input_length=512):
    """
    Returns a len(documents) * max_sequences_per_document * 3 * 512 tensor where len(documents) is the batch
    dimension and the others encode bert input.

    This is the input to any of the document bert architectures.

    :param documents: a list of text documents
    :param tokenizer: the sentence piece bert tokenizer
    :return:
    """
    tokenized_documents = [
        tokenizer.tokenize(document)[:10200] for document in documents
    ]  # added by AD: only take the first 10200 tokens of each document as input
    max_sequences_per_document = math.ceil(
        max(len(x) / (max_input_length - 2) for x in tokenized_documents))
    assert max_sequences_per_document <= 20, "Your document is too large; 20 is an arbitrary cap chosen when this was written"

    output = torch.zeros(size=(len(documents), max_sequences_per_document, 3,
                               512),
                         dtype=torch.long)

    # DistilBERT cannot handle completely empty sequences, so pre-fill every slot with '[CLS]', '[SEP]' followed by zero padding:
    for doc_id in range(len(documents)):
        for seq_id in range(max_sequences_per_document):
            output[doc_id, seq_id, 0] = torch.LongTensor(
                tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]']) + [0] *
                (512 - 2))  #input_ids
            output[doc_id, seq_id, 2] = torch.LongTensor(
                [1] * 2 + [0] * (512 - 2))  #attention_mask

    document_seq_lengths = []  # number of sequences generated per document
    # Each chunk holds max_input_length - 2 (510) document tokens, leaving room for [CLS] and [SEP]
    for doc_index, tokenized_document in enumerate(tokenized_documents):
        max_seq_index = 0
        for seq_index, i in enumerate(
                range(0, len(tokenized_document), (max_input_length - 2))):
            raw_tokens = tokenized_document[i:i + (max_input_length - 2)]
            tokens = []
            input_type_ids = []

            tokens.append("[CLS]")
            input_type_ids.append(0)
            for token in raw_tokens:
                tokens.append(token)
                input_type_ids.append(0)
            tokens.append("[SEP]")
            input_type_ids.append(0)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            attention_masks = [1] * len(input_ids)

            while len(input_ids) < max_input_length:
                input_ids.append(0)
                input_type_ids.append(0)
                attention_masks.append(0)

            assert len(input_ids) == 512 and len(
                attention_masks) == 512 and len(input_type_ids) == 512

            # Stack input_ids, token_type_ids and attention_mask into this (3, 512) slot
            output[doc_index][seq_index] = torch.cat(
                (torch.LongTensor(input_ids).unsqueeze(0),
                 torch.LongTensor(input_type_ids).unsqueeze(0),
                 torch.LongTensor(attention_masks).unsqueeze(0)),
                dim=0)
            max_seq_index = seq_index
        document_seq_lengths.append(max_seq_index + 1)
    return output, torch.LongTensor(document_seq_lengths)
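
A hedged usage sketch for encode_documents; the pretrained model name and the documents are illustrative, and the function needs math, torch and BertTokenizer in scope as noted above:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    docs = ["a short document about tokenization.",
            "another much longer document " * 200]  # long enough to span several 510-token chunks
    batch, seq_lengths = encode_documents(docs, tokenizer)
    print(batch.shape)    # (2, max_sequences_per_document, 3, 512)
    print(seq_lengths)    # number of 512-token sequences generated per document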