Example #1
0
# NOTE: the import block below is an assumption; the paths follow DeepPavlov's
# usual layout and are not part of the original snippet.
import re
from logging import getLogger
from typing import List, Tuple, Union

import numpy as np
from bert_dp.tokenization import FullTokenizer

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.data.utils import zero_pad
from deeppavlov.core.models.component import Component

log = getLogger(__name__)


class BertNerPreprocessor(Component):
    """Takes tokens and splits them into bert subtokens, encode subtokens with their indices.
    Creates mask of subtokens (one for first subtoken, zero for later subtokens).

    If tags are provided, calculate tags for subtokens.

    Args:
        vocab_file: path to vocabulary
        do_lower_case: set True if lowercasing is needed
        max_seq_length: max sequence length in subtokens, including [SEP] and [CLS] tokens
        max_subword_length: replace a token with [UNK] if the number of its subwords is larger than this
            (defaults to None, which means no limit)
        token_masking_prob: probability of masking a token during training
        provide_subword_tags: whether to return tags for subwords (True) or for the original words (False)

    Attributes:
        max_seq_length: max sequence length in subtokens, including [SEP] and [CLS] tokens
        max_subword_length: max length of a BERT token in subwords
        tokenizer: instance of BERT FullTokenizer
    """
    def __init__(self,
                 vocab_file: str,
                 do_lower_case: bool = True,
                 max_seq_length: int = 512,
                 max_subword_length: int = None,
                 token_masking_prob: float = 0.0,
                 provide_subword_tags: bool = False,
                 **kwargs):
        self._re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.provide_subword_tags = provide_subword_tags
        self.mode = kwargs.get('mode')
        self.max_seq_length = max_seq_length
        self.max_subword_length = max_subword_length
        vocab_file = str(expand_path(vocab_file))
        self.tokenizer = FullTokenizer(vocab_file=vocab_file,
                                       do_lower_case=do_lower_case)
        self.token_masking_prob = token_masking_prob

    def __call__(self,
                 tokens: Union[List[List[str]], List[str]],
                 tags: List[List[str]] = None,
                 **kwargs):
        # if raw strings are passed, split them into word-level tokens first
        if isinstance(tokens[0], str):
            tokens = [re.findall(self._re_tokenizer, s) for s in tokens]
        subword_tokens, subword_tok_ids, subword_masks, subword_tags = [], [], [], []
        for i in range(len(tokens)):
            toks = tokens[i]
            ys = ['O'] * len(toks) if tags is None else tags[i]
            mask = [int(y != 'X') for y in ys]
            assert len(toks) == len(ys) == len(mask), \
                f"toks({len(toks)}) should have the same length as " \
                f" ys({len(ys)}) and mask({len(mask)}), tokens = {toks}."
            sw_toks, sw_mask, sw_ys = self._ner_bert_tokenize(
                toks,
                mask,
                ys,
                self.tokenizer,
                self.max_subword_length,
                mode=self.mode,
                token_masking_prob=self.token_masking_prob)
            if self.max_seq_length is not None:
                if len(sw_toks) > self.max_seq_length:
                    raise RuntimeError(
                        f"input sequence after bert tokenization"
                        f" shouldn't exceed {self.max_seq_length} tokens.")
            subword_tokens.append(sw_toks)
            subword_tok_ids.append(
                self.tokenizer.convert_tokens_to_ids(sw_toks))
            subword_masks.append(sw_mask)
            subword_tags.append(sw_ys)
            assert len(sw_mask) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys), \
                f"length of mask({len(sw_mask)}), tokens({len(sw_toks)})," \
                f" token ids({len(subword_tok_ids[-1])}) and ys({len(ys)})" \
                f" for tokens = `{toks}` should match"
        subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
        subword_masks = zero_pad(subword_masks, dtype=int, padding=0)
        if tags is not None:
            if self.provide_subword_tags:
                return tokens, subword_tokens, subword_tok_ids, subword_masks, subword_tags
            else:
                nonmasked_tags = [[t for t in ts if t != 'X'] for ts in tags]
                for swts, swids, swms, ts in zip(subword_tokens,
                                                 subword_tok_ids,
                                                 subword_masks,
                                                 nonmasked_tags):
                    if (len(swids) != len(swms)) or (len(ts) != sum(swms)):
                        log.warning(
                            'Not matching lengths of the tokenization!')
                        log.warning(
                            f'Tokens len: {len(swts)}\n Tokens: {swts}')
                        log.warning(
                            f'Masks len: {len(swms)}, sum: {sum(swms)}')
                        log.warning(f'Masks: {swms}')
                        log.warning(f'Tags len: {len(ts)}\n Tags: {ts}')
                return tokens, subword_tokens, subword_tok_ids, subword_masks, nonmasked_tags
        return tokens, subword_tokens, subword_tok_ids, subword_masks

    @staticmethod
    def _ner_bert_tokenize(
        tokens: List[str],
        mask: List[int],
        tags: List[str],
        tokenizer: FullTokenizer,
        max_subword_len: int = None,
        mode: str = None,
        token_masking_prob: float = 0.0
    ) -> Tuple[List[str], List[int], List[str]]:
        tokens_subword = ['[CLS]']
        mask_subword = [0]
        tags_subword = ['X']
        for token, flag, tag in zip(tokens, mask, tags):
            subwords = tokenizer.tokenize(token)
            if not subwords or \
                    ((max_subword_len is not None) and (len(subwords) > max_subword_len)):
                tokens_subword.append('[UNK]')
                mask_subword.append(flag)
                tags_subword.append(tag)
            else:
                if mode == 'train' and token_masking_prob > 0.0 \
                        and np.random.rand() < token_masking_prob:
                    tokens_subword.extend(['[MASK]'] * len(subwords))
                else:
                    tokens_subword.extend(subwords)
                mask_subword.extend([flag] + [0] * (len(subwords) - 1))
                tags_subword.extend([tag] + ['X'] * (len(subwords) - 1))

        tokens_subword.append('[SEP]')
        mask_subword.append(0)
        tags_subword.append('X')
        return tokens_subword, mask_subword, tags_subword
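
A minimal usage sketch for the class above (not taken from the source): it assumes DeepPavlov and the bert_dp package are installed and that a BERT vocabulary file exists at the hypothetical path ./vocab.txt.

# Hedged example; the vocab path and the sample sentence are illustrative only.
preprocessor = BertNerPreprocessor(vocab_file='./vocab.txt',  # hypothetical path
                                   do_lower_case=True,
                                   max_seq_length=512)
tokens = [['John', 'lives', 'in', 'New', 'York']]
tags = [['B-PER', 'O', 'O', 'B-LOC', 'I-LOC']]
# With tags given and provide_subword_tags=False (the default), __call__ returns the
# original tokens, subword tokens, subword ids, subword masks and word-level tags.
toks, sw_toks, sw_ids, sw_mask, word_tags = preprocessor(tokens, tags)
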
class BertNerPreprocessor(Component):
    """Takes tokens and splits them into bert subtokens, encode subtokens with their indices.
    Creates mask of subtokens (one for first subtoken, zero for later subtokens).
    
    If tags are provided, calculate tags for subtokens.

    Args:
        vocab_file: path to vocabulary
        do_lower_case: set True if lowercasing is needed
        max_seq_length: max sequence length in subtokens, including [SEP] and [CLS] tokens
        max_subword_length: replace a token with [UNK] if the number of its subwords is larger than this
            (defaults to None, which means no limit)

    Attributes:
        max_seq_length: max sequence length in subtokens, including [SEP] and [CLS] tokens
        max_subword_length: max length of a BERT token in subwords
        tokenizer: instance of BERT FullTokenizer
    """
    def __init__(self,
                 vocab_file: str,
                 do_lower_case: bool = True,
                 max_seq_length: int = 512,
                 max_subword_length: int = None,
                 **kwargs):
        self.max_seq_length = max_seq_length
        self.max_subword_length = max_subword_length
        vocab_file = str(expand_path(vocab_file))
        self.tokenizer = FullTokenizer(vocab_file=vocab_file,
                                       do_lower_case=do_lower_case)

    def __call__(self,
                 tokens: List[List[str]],
                 tags: List[List[str]] = None,
                 **kwargs):
        subword_tokens, subword_tok_ids, subword_masks, subword_tags = [], [], [], []
        for i in range(len(tokens)):
            toks = tokens[i]
            ys = ['X'] * len(toks) if tags is None else tags[i]
            assert len(toks) == len(ys), \
                f"toks({len(toks)}) should have the same length as "\
                f" ys({len(ys)}), tokens = {toks}."
            sw_toks, sw_mask, sw_ys = self._ner_bert_tokenize(
                toks, [1] * len(toks), ys, self.tokenizer,
                self.max_subword_length)
            if self.max_seq_length is not None:
                sw_toks = sw_toks[:self.max_seq_length]
                sw_mask = sw_mask[:self.max_seq_length]
                sw_ys = sw_ys[:self.max_seq_length]

                # restore the [SEP] token if it was cut off by truncation
                if sw_toks[-1] != '[SEP]':
                    sw_toks[-1] = '[SEP]'
                    sw_mask[-1] = 0
                    sw_ys[-1] = 'X'
            subword_tokens.append(sw_toks)
            subword_tok_ids.append(
                self.tokenizer.convert_tokens_to_ids(sw_toks))
            subword_masks.append(sw_mask)
            subword_tags.append(sw_ys)
            assert len(sw_mask) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys),\
                f"length of mask({len(sw_mask)}), tokens({len(sw_toks)}),"\
                f" token ids({len(subword_tok_ids[-1])}) and ys({len(ys)})"\
                f" for tokens = `{toks}` should match"
        subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
        subword_masks = zero_pad(subword_masks, dtype=int, padding=0)
        if tags is not None:
            return subword_tokens, subword_tok_ids, subword_masks, subword_tags
        return subword_tokens, subword_tok_ids, subword_masks

    @staticmethod
    def _ner_bert_tokenize(
            tokens: List[str],
            mask: List[int],
            tags: List[str],
            tokenizer: FullTokenizer,
            max_subword_len: int = None) -> Tuple[List[str], List[int], List[str]]:
        tokens_subword = ['[CLS]']
        mask_subword = [0]
        tags_subword = ['X']

        for token, flag, tag in zip(tokens, mask, tags):
            subwords = tokenizer.tokenize(token)
            if not subwords or\
                    ((max_subword_len is not None) and (len(subwords) > max_subword_len)):
                tokens_subword.append('[UNK]')
                mask_subword.append(0)
                tags_subword.append('X')
            else:
                tokens_subword.extend(subwords)
                mask_subword.extend([flag] + [0] * (len(subwords) - 1))
                tags_subword.extend([tag] + ['X'] * (len(subwords) - 1))

        tokens_subword.append('[SEP]')
        mask_subword.append(0)
        tags_subword.append('X')
        return tokens_subword, mask_subword, tags_subword
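
A similar hedged sketch for the second variant above (not from the source): unlike the first version, it expects pre-tokenized input only, truncates sequences longer than max_seq_length instead of raising a RuntimeError, and does not return the original word tokens.

# Hedged example for the second variant; the vocab path is illustrative only.
preprocessor = BertNerPreprocessor(vocab_file='./vocab.txt', max_seq_length=16)
tokens = [['John', 'lives', 'in', 'New', 'York']]
sw_toks, sw_ids, sw_mask = preprocessor(tokens)  # no tags -> three outputs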