import re
from logging import getLogger
from typing import List, Optional, Union

# NOTE: the two imports below assume the DeepPavlov utilities this snippet relies on
# (zero_pad and Mask); adjust the paths if the surrounding module differs.
from deeppavlov.core.data.utils import zero_pad
from deeppavlov.models.preprocessors.mask import Mask

log = getLogger(__name__)


def __call__(self,
                 tokens: Union[List[List[str]], List[str]],
                 tags: Optional[List[List[str]]] = None,
                 **kwargs):
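        """Tokenize a batch of pre-tokenized or raw texts into subwords and align tags.

        If ``tokens`` is a batch of raw strings, each string is first split with the
        regex tokenizer. Returns the tokens, subword tokens, padded subword token ids,
        start-of-word markers and an attention mask; when ``tags`` are given, the
        aligned tags are returned as well.
        """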
        if isinstance(tokens[0], str):
            tokens = [re.findall(self._re_tokenizer, s) for s in tokens]
        subword_tokens, subword_tok_ids, startofword_markers, subword_tags = [], [], [], []
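        # Subword-tokenize every utterance, keeping start-of-word markers and tags aligned.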
        for i in range(len(tokens)):
            toks = tokens[i]
            ys = ['O'] * len(toks) if tags is None else tags[i]
            assert len(toks) == len(ys), \
                f"toks({len(toks)}) should have the same length as ys({len(ys)})"
            sw_toks, sw_marker, sw_ys = \
                self._ner_bert_tokenize(toks,
                                        ys,
                                        self.tokenizer,
                                        self.max_subword_length,
                                        mode=self.mode,
                                        subword_mask_mode=self.subword_mask_mode,
                                        token_masking_prob=self.token_masking_prob)
            if self.max_seq_length is not None and len(sw_toks) > self.max_seq_length:
                raise RuntimeError(f"input sequence after BERT tokenization must not"
                                   f" exceed {self.max_seq_length} tokens")
            subword_tokens.append(sw_toks)
            subword_tok_ids.append(self.tokenizer.convert_tokens_to_ids(sw_toks))
            startofword_markers.append(sw_marker)
            subword_tags.append(sw_ys)
            assert len(sw_marker) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys), \
                f"length of sw_marker({len(sw_marker)}), tokens({len(sw_toks)})," \
                f" token ids({len(subword_tok_ids[-1])}) and sw_ys({len(sw_ys)})" \
                f" for tokens = `{toks}` should match"

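        # Pad all sequences to the longest one in the batch and build the attention mask.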
        subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
        startofword_markers = zero_pad(startofword_markers, dtype=int, padding=0)
        attention_mask = Mask()(subword_tokens)

        if tags is not None:
            if self.provide_subword_tags:
                return tokens, subword_tokens, subword_tok_ids, \
                       attention_mask, startofword_markers, subword_tags
            else:
                nonmasked_tags = [[t for t in ts if t != 'X'] for ts in tags]
                for swts, swids, swms, ts in zip(subword_tokens,
                                                 subword_tok_ids,
                                                 startofword_markers,
                                                 nonmasked_tags):
                    if (len(swids) != len(swms)) or (len(ts) != sum(swms)):
                        log.warning('Not matching lengths of the tokenization!')
                        log.warning(f'Tokens len: {len(swts)}\n Tokens: {swts}')
                        log.warning(f'Markers len: {len(swms)}, sum: {sum(swms)}')
                        log.warning(f'Masks: {swms}')
                        log.warning(f'Tags len: {len(ts)}\n Tags: {ts}')
                return tokens, subword_tokens, subword_tok_ids, \
                       attention_mask, startofword_markers, nonmasked_tags
        return tokens, subword_tokens, subword_tok_ids, startofword_markers, attention_mask
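
    # Usage sketch (illustrative, not part of the original code): assuming a configured
    # instance `ner_preprocessor` of the class this method belongs to, with its
    # `tokenizer`, `max_subword_length` and `max_seq_length` already set up:
    #
    #     tokens, subword_tokens, subword_tok_ids, startofword_markers, attention_mask = \
    #         ner_preprocessor(["Barack Obama was born in Hawaii"])
    #
    # When `tags` are passed as well, the aligned (non-masked) tags are returned and the
    # attention mask precedes the start-of-word markers in the output (see the returns above).
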
    def __call__(self,
                 tokens_batch,
                 entity_offsets_batch,
                 mentions_batch=None,
                 pages_batch=None):
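        """Convert raw texts and entity character offsets into BERT inputs.

        Each text is split with the regex tokenizer, wordpiece-tokenized, wrapped in
        "[CLS]"/"[SEP]" tokens and converted to padded token ids with an attention
        mask; entity character spans are mapped to the indices of their subwords.
        ``mentions_batch`` and ``pages_batch`` default to empty lists per sample.
        """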
        subw_tokens_batch, entity_subw_indices_batch = [], []
        if mentions_batch is None:
            mentions_batch = [[] for _ in tokens_batch]
        if pages_batch is None:
            pages_batch = [[] for _ in tokens_batch]

        for tokens, entity_offsets_list, mentions_list, pages_list in zip(
                tokens_batch, entity_offsets_batch, mentions_batch,
                pages_batch):
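            # Split the text into tokens together with their character offsets.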
            tokens_list = []
            tokens_offsets_list = []
            for elem in re.finditer(self._re_tokenizer, tokens):
                tokens_list.append(elem[0])
                tokens_offsets_list.append((elem.start(), elem.end()))

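            # Map each entity character span to the indices of the tokens it covers;
            # if no token falls entirely inside the span, fall back to the first token
            # that starts at or after the span.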
            entity_indices_list = []
            for start_offset, end_offset in entity_offsets_list:
                entity_indices = []
                for ind, (start_tok_offset, end_tok_offset) in enumerate(tokens_offsets_list):
                    if start_tok_offset >= start_offset and end_tok_offset <= end_offset:
                        entity_indices.append(ind)
                if not entity_indices:
                    for ind, (start_tok_offset, end_tok_offset) in enumerate(tokens_offsets_list):
                        if start_tok_offset >= start_offset:
                            entity_indices.append(ind)
                            break
                entity_indices_list.append(set(entity_indices))

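            # Wordpiece-tokenize every token and translate token-level entity indices
            # into subword positions, shifted by one to account for the leading "[CLS]".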
            ind = 0
            subw_tokens_list = ["[CLS]"]
            entity_subw_indices_list = [[] for _ in entity_indices_list]
            for n, tok in enumerate(tokens_list):
                subw_tok = self.tokenizer.tokenize(tok)
                subw_tokens_list += subw_tok
                for j in range(len(entity_indices_list)):
                    if n in entity_indices_list[j]:
                        for k in range(len(subw_tok)):
                            entity_subw_indices_list[j].append(ind + k + 1)
                ind += len(subw_tok)
            subw_tokens_list.append("[SEP]")
            subw_tokens_batch.append(subw_tokens_list)

            for n in range(len(entity_subw_indices_list)):
                entity_subw_indices_list[n] = sorted(entity_subw_indices_list[n])
            entity_subw_indices_batch.append(entity_subw_indices_list)

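        # Convert subword tokens to ids, pad to the longest sequence and build the attention mask.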
        token_ids_batch = [
            self.tokenizer.convert_tokens_to_ids(subw_tokens_list)
            for subw_tokens_list in subw_tokens_batch
        ]
        token_ids_batch = zero_pad(token_ids_batch, dtype=int, padding=0)
        attention_mask_batch = Mask()(subw_tokens_batch)

        return token_ids_batch, attention_mask_batch, entity_subw_indices_batch
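
    # Usage sketch (illustrative, not part of the original code): assuming a configured
    # instance `entity_preprocessor` of the class this method belongs to, with a
    # wordpiece `tokenizer` already set up:
    #
    #     text = "Barack Obama was born in Hawaii"
    #     entity_offsets = [(0, 12), (25, 31)]   # character spans of the two entities
    #     token_ids, attention_mask, entity_subw_indices = \
    #         entity_preprocessor([text], [entity_offsets])
    #
    #     # entity_subw_indices[0][0] holds the subword positions of "Barack Obama",
    #     # offset by +1 because "[CLS]" is prepended to every sequence.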