Example no. 1
    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname,
                                      VOCAB_FILES_NAMES["vocab_file"])
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            input_text = "UNwant\u00E9d,running"
            output_text = "unwanted, running"

            create_and_check_tokenizer_commons(self, input_text, output_text,
                                               BertTokenizer, tmpdirname)

            tokenizer = BertTokenizer(vocab_file)

            tokens = tokenizer.tokenize("UNwant\u00E9d,running")
            self.assertListEqual(
                tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                                 [7, 4, 5, 10, 8, 9])
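In the toy vocabulary above, "un" has index 7, "##want" 4, "##ed" 5, "," 10, "runn" 8 and "##ing" 9, which is exactly the id list asserted at the end. The following is a minimal sketch of the greedy longest-match-first lookup that WordPiece applies to each lowercased, accent-stripped word (illustrative only, not the library implementation; the helper name is made up):
def greedy_wordpiece(word, vocab):
    # Greedy longest-match-first split of a single pre-tokenized, lowercased word.
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return ["[UNK]"]  # no prefix of the remainder is in the vocabulary
        pieces.append(piece)
        start = end
    return pieces

toy_vocab = {tok: i for i, tok in enumerate([
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
    "runn", "##ing", ",", "low", "lowest",
])}
print(greedy_wordpiece("unwanted", toy_vocab))   # ['un', '##want', '##ed']
print(greedy_wordpiece("running", toy_vocab))    # ['runn', '##ing']
print([toy_vocab[t] for t in ["un", "##want", "##ed", ",", "runn", "##ing"]])  # [7, 4, 5, 10, 8, 9]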
Example no. 2
    def test_full_tokenizer(self):
        tokenizer = BertTokenizer(self.vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens,
                             ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
                             [7, 4, 5, 10, 8, 9])
Example no. 3
import math

import torch

from transformers import BertTokenizer  # assumed import; older code may use pytorch_pretrained_bert


def encode_documents(documents: list, tokenizer: BertTokenizer, max_input_length=512):
    """
    Returns a len(documents) * max_sequences_per_document * 3 * 512 tensor where len(documents) is the batch
    dimension and the others encode bert input.

    This is the input to any of the document bert architectures.

    :param documents: a list of text documents
    :param tokenizer: the sentence piece bert tokenizer
    :return:
    """
    tokenized_documents = [tokenizer.tokenize(document) for document in documents]
    max_sequences_per_document = math.ceil(max(len(x)/(max_input_length-2) for x in tokenized_documents))
    assert max_sequences_per_document <= 300, "Document is too large; the 300-sequence limit is arbitrary"

    output = torch.zeros(size=(len(documents), max_sequences_per_document, 3, max_input_length), dtype=torch.long)
    document_seq_lengths = []  # number of sequences generated per document
    # Each chunk holds max_input_length - 2 wordpieces to leave room for [CLS] and [SEP]
    for doc_index, tokenized_document in enumerate(tokenized_documents):
        max_seq_index = 0
        for seq_index, i in enumerate(range(0, len(tokenized_document), (max_input_length-2))):
            raw_tokens = tokenized_document[i:i+(max_input_length-2)]
            tokens = []
            input_type_ids = []

            tokens.append("[CLS]")
            input_type_ids.append(0)
            for token in raw_tokens:
                tokens.append(token)
                input_type_ids.append(0)
            tokens.append("[SEP]")
            input_type_ids.append(0)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            attention_masks = [1] * len(input_ids)

            while len(input_ids) < max_input_length:
                input_ids.append(0)
                input_type_ids.append(0)
                attention_masks.append(0)

            assert (len(input_ids) == max_input_length
                    and len(attention_masks) == max_input_length
                    and len(input_type_ids) == max_input_length)

            # Stack input ids, token type ids and attention mask for this sequence
            output[doc_index][seq_index] = torch.cat((torch.LongTensor(input_ids).unsqueeze(0),
                                                           torch.LongTensor(input_type_ids).unsqueeze(0),
                                                           torch.LongTensor(attention_masks).unsqueeze(0)),
                                                          dim=0)
            max_seq_index = seq_index
        document_seq_lengths.append(max_seq_index+1)
    return output, torch.LongTensor(document_seq_lengths)
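A minimal usage sketch, assuming the imports added above and the standard bert-base-uncased vocabulary; the document strings are placeholders:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
docs = ["a short document", "a much longer document " * 400]
batch, seq_lengths = encode_documents(docs, tokenizer)
print(batch.shape)    # (len(docs), max_sequences_per_document, 3, 512)
print(seq_lengths)    # number of sequences actually used per document
# Along dimension 2: index 0 holds the input ids, 1 the token type ids, 2 the attention mask.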
Example no. 4
    def tokenize_one_(self, msg: str, tokenizer: BertTokenizer = None):
        bert_tokens = []
        wp_starts = []
        truncated = False
        if tokenizer is None:
            raise ValueError('Tokenizer cannot be None.')

        tokens = pre_processing.tokenise(f'{msg["title"]}. {msg["body"]}',
                                         lowercase=False,
                                         simple=True,
                                         remove_stopwords=False)
        for i_token, token_str in enumerate(tokens):
            skip_token = False
            wordpieces = tokenizer.tokenize(token_str)

            if not wordpieces:
                # this mainly happens for strange unicode characters
                token_str = '[UNK]'
                wordpieces = tokenizer.tokenize(token_str)
                skip_token = True

            if len(bert_tokens) + len(wordpieces) > 510:
                # bert model is limited to 512 tokens
                truncated = True
                break

            if not skip_token:
                wp_starts.append(len(bert_tokens) + 1)  # first token is [CLS]

            bert_tokens.extend(wordpieces)
        bert_tokens = ['[CLS]'] + bert_tokens + ['[SEP]']

        assert len(bert_tokens) <= 512, f'{len(bert_tokens)} > 512'

        bert_ids = tokenizer.convert_tokens_to_ids(bert_tokens)
        return bert_ids, wp_starts, truncated
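wp_starts records, for every original token that was kept, the position of its first wordpiece in bert_tokens (offset by one for the leading [CLS]). A purely illustrative sketch of the usual downstream step, selecting one hidden vector per original token; all values below are made up:
import torch

bert_ids = [101, 7592, 2088, 102]             # e.g. [CLS] hello world [SEP] (illustrative ids)
wp_starts = [1, 2]                            # first wordpiece of each original token
hidden = torch.randn(1, len(bert_ids), 768)   # stand-in for the BERT output for this message
token_vectors = hidden[0, wp_starts]          # shape (len(wp_starts), 768), one vector per token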
def convert_examples_to_spans(
    examples: List[Example],
    ner_tags_converter: NERTagsEncoder,
    tokenizer: BertTokenizer,
    max_seq_length: int,
    doc_stride: int,
    is_training: bool,
    unique_id_start: Optional[int] = None,
    verbose: bool = True,
) -> List[InputSpan]:
    """Converts examples to BERT input-ready data tensor-like structures,
    splitting large documents into spans of `max_seq_length` using a stride of
    `doc_stride` tokens."""

    unique_id = unique_id_start or 1000000000

    features = []
    for (example_index, example) in enumerate(examples):

        doc_tokens = example.doc_tokens
        doc_labels = example.labels

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        all_doc_labels = []
        all_prediction_mask = []

        for i, token in enumerate(doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token.text)
            for j, sub_token in enumerate(sub_tokens):
                # Create mapping from subtokens to original token
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)
                # Mask all subtokens (j > 0)
                all_prediction_mask.append(j == 0)

                if j == 0:
                    label = doc_labels[i]
                    all_doc_labels.append(label)
                else:
                    all_doc_labels.append('X')

        assert len(all_doc_tokens) == len(all_prediction_mask)
        if is_training:
            assert len(all_doc_tokens) == len(all_doc_labels)

        # The -1 accounts for [CLS]. For NER we have only one sentence, so no
        # [SEP] tokens.
        max_tokens_for_doc = max_seq_length - 1

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = []
            segment_ids = []
            labels = None
            label_ids = None
            prediction_mask = []
            # Include [CLS] token
            tokens.append("[CLS]")
            segment_ids.append(0)
            prediction_mask.append(False)

            # Ignore [CLS] label
            if is_training:
                labels = ['X']

            for i in range(doc_span.length):
                # Each doc span will have a dict that indicates if it is the
                # *max_context span* for the tokens inside it
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans,
                                                       doc_span_index,
                                                       split_token_index)
                token_is_max_context.append(is_max_context)
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(0)
                if is_training:
                    labels.append(all_doc_labels[split_token_index])
                prediction_mask.append(all_prediction_mask[split_token_index])

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            if is_training:
                label_ids = ner_tags_converter.convert_tags_to_ids(labels)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                if is_training:
                    label_ids.append(ner_tags_converter.ignore_index)
                prediction_mask.append(False)

            # If not training, use placeholder labels
            if not is_training:
                labels = ['O'] * len(input_ids)
                label_ids = [ner_tags_converter.ignore_index] * len(input_ids)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(prediction_mask) == max_seq_length
            if is_training:
                assert len(label_ids) == max_seq_length

            if verbose and example_index < 20:
                LOGGER.info("*** Example ***")
                LOGGER.info("unique_id: %s" % (unique_id))
                LOGGER.info("example_index: %s" % (example_index))
                LOGGER.info("doc_span_index: %s" % (doc_span_index))
                LOGGER.info("tokens: %s" % " ".join(tokens))
                LOGGER.info("token_to_orig_map: %s" % " ".join(
                    ["%d:%d" % (x, y)
                     for (x, y) in token_to_orig_map.items()]))
                LOGGER.info("token_is_max_context: %s", token_is_max_context)
                LOGGER.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
                LOGGER.info("input_mask: %s" %
                            " ".join([str(x) for x in input_mask]))
                LOGGER.info("segment_ids: %s" %
                            " ".join([str(x) for x in segment_ids]))
                LOGGER.info("prediction_mask: %s" %
                            " ".join([str(x) for x in prediction_mask]))
                if is_training:
                    LOGGER.info("label_ids: %s" %
                                " ".join([str(x) for x in label_ids]))

                LOGGER.info("tags:")
                inside_label = False
                for tok, lab, lab_id in zip(tokens, labels, label_ids):
                    if lab[0] == "O":
                        if inside_label and tok.startswith("##"):
                            LOGGER.info(f'{tok}\tX')
                        else:
                            inside_label = False
                    else:
                        if lab[0] in ("B", "I", "L", "U") or inside_label:
                            if lab[0] in ("B", "U"):
                                # new entity
                                LOGGER.info('')
                            inside_label = True
                            LOGGER.info(f'{tok}\t{lab}\t{lab_id}')

            features.append(
                InputSpan(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    labels=labels,
                    label_ids=label_ids,
                    prediction_mask=prediction_mask,
                ))
            unique_id += 1

    return features
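To make the span splitting concrete, here is the sliding-window loop from convert_examples_to_spans pulled out on its own, with a worked example (the numbers are illustrative):
import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def split_into_spans(num_tokens, max_tokens_for_doc, doc_stride):
    # Same logic as the while-loop above, made standalone for illustration.
    spans, start_offset = [], 0
    while start_offset < num_tokens:
        length = min(num_tokens - start_offset, max_tokens_for_doc)
        spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_tokens:
            break
        start_offset += min(length, doc_stride)
    return spans

# 600 wordpieces with max_seq_length=512 (so max_tokens_for_doc=511) and doc_stride=128
# yield two overlapping windows; they share 511 - 128 = 383 tokens.
print(split_into_spans(600, 511, 128))
# [DocSpan(start=0, length=511), DocSpan(start=128, length=472)]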
    taggerOne_pred_mention_types = {}
    taggerOne_pred_mention_labels = {}
    with open(PRED_PUBTATOR_FILE, 'r') as f:
        for line in f:
            line_split = line.strip().split('\t')
            # PubTator-style annotation lines have six tab-separated fields:
            # pmid, span start, span end, mention text, type, concept id
            if len(line_split) == 6:
                if line_split[0] not in test_pmids:
                    continue
                pred_key = (line_split[0], line_split[1], line_split[2])
                taggerOne_pred_mention_types[pred_key] = line_split[4]
                taggerOne_pred_mention_labels[pred_key] = line_split[
                    5].replace('UMLS:', '')

    # tokenize all of the documents
    tokenized_docs = {}
    for pmid, raw_text in raw_docs.items():
        wp_tokens = tokenizer.tokenize(raw_text)
        tokenized_text = ' '.join(wp_tokens).replace(' ##', '')
        tokenized_docs[pmid] = tokenized_text
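    # Illustrative note: joining wordpieces and stripping ' ##' undoes the WordPiece splits,
    # e.g. ['un', '##want', '##ed'] -> 'un ##want ##ed' -> 'unwanted'.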

    # get all of the mentions and their tfidf candidates in raw form
    print('Reading pred mentions and tfidf candidates...')
    pred_mention_cands = defaultdict(list)
    with open(PRED_MATCHES_FILE, 'r') as f:
        reader = csv.reader(f, delimiter="\t", quotechar='"')
        keys = next(reader)
        for row in tqdm(reader):
            if row[0] not in test_pmids:
                continue
            pred_mention_key = (row[0], row[1], row[2])
            pred_mention_cand_val = {k: v for k, v in zip(keys, row)}
            pred_mention_cands[pred_mention_key].append(pred_mention_cand_val)