Example No. 1
 def __call__(self,
              tokens: Union[List[List[str]], List[str]],
              tags: List[List[str]] = None,
              **kwargs):
     if isinstance(tokens[0], str):
         tokens = [re.findall(self._re_tokenizer, s) for s in tokens]
     subword_tokens, subword_tok_ids, subword_masks, subword_tags = [], [], [], []
     for i in range(len(tokens)):
         toks = tokens[i]
         ys = ['O'] * len(toks) if tags is None else tags[i]
         mask = [int(y != 'X') for y in ys]
         assert len(toks) == len(ys) == len(mask), \
             f"toks({len(toks)}) should have the same length as " \
             f" ys({len(ys)}) and mask({len(mask)}), tokens = {toks}."
         sw_toks, sw_mask, sw_ys = self._ner_bert_tokenize(
             toks,
             mask,
             ys,
             self.tokenizer,
             self.max_subword_length,
             mode=self.mode,
             token_maksing_prob=self.token_maksing_prob)
         if self.max_seq_length is not None:
             if len(sw_toks) > self.max_seq_length:
                 raise RuntimeError(
                     f"input sequence after bert tokenization"
                     f" shouldn't exceed {self.max_seq_length} tokens.")
         subword_tokens.append(sw_toks)
         subword_tok_ids.append(
             self.tokenizer.convert_tokens_to_ids(sw_toks))
         subword_masks.append(sw_mask)
         subword_tags.append(sw_ys)
         assert len(sw_mask) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys), \
             f"length of mask({len(sw_mask)}), tokens({len(sw_toks)})," \
             f" token ids({len(subword_tok_ids[-1])}) and ys({len(ys)})" \
             f" for tokens = `{toks}` should match"
     subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
     subword_masks = zero_pad(subword_masks, dtype=int, padding=0)
     if tags is not None:
         if self.provide_subword_tags:
             return tokens, subword_tokens, subword_tok_ids, subword_masks, subword_tags
         else:
             nonmasked_tags = [[t for t in ts if t != 'X'] for ts in tags]
             for swts, swids, swms, ts in zip(subword_tokens,
                                              subword_tok_ids,
                                              subword_masks,
                                              nonmasked_tags):
                 if (len(swids) != len(swms)) or (len(ts) != sum(swms)):
                     log.warning(
                         'Not matching lengths of the tokenization!')
                     log.warning(
                         f'Tokens len: {len(swts)}\n Tokens: {swts}')
                     log.warning(
                         f'Masks len: {len(swms)}, sum: {sum(swms)}')
                     log.warning(f'Masks: {swms}')
                     log.warning(f'Tags len: {len(ts)}\n Tags: {ts}')
             return tokens, subword_tokens, subword_tok_ids, subword_masks, nonmasked_tags
     return tokens, subword_tokens, subword_tok_ids, subword_masks
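The examples on this page all rely on zero_pad to turn ragged per-sample lists into a rectangular batch. A minimal sketch of that padding behaviour (an illustration only, not DeepPavlov's actual zero_pad implementation):

import numpy as np

def zero_pad_sketch(batch, dtype=np.float32, padding=0):
    """Pad variable-length sequences with `padding` up to the longest sample in the batch."""
    max_len = max(len(sample) for sample in batch)
    padded = np.full((len(batch), max_len), padding, dtype=dtype)
    for i, sample in enumerate(batch):
        padded[i, :len(sample)] = sample
    return padded

# Ragged subword id lists become one int matrix; shorter rows are padded with 0.
print(zero_pad_sketch([[101, 7592, 102], [101, 102]], dtype=int))
# [[ 101 7592  102]
#  [ 101  102    0]]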
Example No. 2
    def __call__(self, batch: List[List[str]], tags_batch: Optional[List[List[str]]] = None, mean: bool = None,
                 *args, **kwargs) -> List[Union[list, np.ndarray]]:
        """
        Infer on the given data

        Args:
            batch: tokenized text samples
            tags_batch: optional batch of corresponding tags
            mean: whether to return mean token embedding (does not depend on self.mean)
            *args: additional arguments
            **kwargs: additional arguments

        Returns:
            embedded batch
        """

        if self.tags_vocab:
            if tags_batch is None:
                raise ConfigError("TfidfWeightedEmbedder got 'tags_vocab_path' but __call__ did not get tags_batch.")
            batch = [self._tags_encode(sample, tags_sample, mean=mean) for sample, tags_sample in zip(batch, tags_batch)]
        else:
            if tags_batch:
                raise ConfigError("TfidfWeightedEmbedder got tags batch, but 'tags_vocab_path' is empty.")
            batch = [self._encode(sample, mean=mean) for sample in batch]

        if self.pad_zero:
            batch = zero_pad(batch)

        return batch
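A rough sketch of what the _encode step of a tf-idf weighted embedder might do; the token vectors and idf weights below are hypothetical stand-ins, since the real TfidfWeightedEmbedder derives them from its fitted vectorizer and vocabulary:

import numpy as np

# Hypothetical per-token vectors and idf weights, only to make the sketch runnable.
vectors = {"deep": np.ones(4, dtype=np.float32), "pavlov": np.full(4, 2.0, dtype=np.float32)}
idf = {"deep": 0.5, "pavlov": 1.5}

def tfidf_weighted_encode_sketch(tokens, mean=True, dim=4):
    """Weight each token vector by its tf-idf score; average over tokens if mean is True."""
    weighted = [idf.get(t, 1.0) * vectors.get(t, np.zeros(dim, dtype=np.float32)) for t in tokens]
    if not weighted:
        weighted = [np.zeros(dim, dtype=np.float32)]
    return np.mean(weighted, axis=0) if mean else weighted

print(tfidf_weighted_encode_sketch(["deep", "pavlov"]))   # [1.75 1.75 1.75 1.75]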
Example No. 3
    def __call__(self, batch: List[List[int]], **kwargs) -> Union[List[List[np.ndarray]], List[np.ndarray]]:
        """
        Convert given batch of list of labels to one-hot representation of the batch.

        Args:
            batch: list of samples, where each sample is a list of integer labels.
            **kwargs: additional arguments

        Returns:
            if ``single_vector``, list of one-hot representations of each sample,
            otherwise, list of lists of one-hot representations of each label in a sample
        """
        one_hotted_batch = []

        for utt in batch:
            if isinstance(utt, list):
                one_hotted_utt = self._to_one_hot(utt, self._depth)
            elif isinstance(utt, int):
                if self._pad_zeros or self.single_vector:
                    one_hotted_utt = self._to_one_hot([utt], self._depth)
                else:
                    one_hotted_utt = self._to_one_hot([utt], self._depth).reshape(-1)

            if self.single_vector:
                one_hotted_utt = np.sum(one_hotted_utt, axis=0)

            one_hotted_batch.append(one_hotted_utt)

        if self._pad_zeros:
            one_hotted_batch = zero_pad(one_hotted_batch)
        return one_hotted_batch
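The _to_one_hot helper is not shown here; a minimal sketch of it, together with the single_vector summation that collapses a sample's labels into one multi-hot vector (assumed behaviour, for illustration):

import numpy as np

def to_one_hot_sketch(labels, depth):
    """Return an array of shape (len(labels), depth) with a 1 at each label index."""
    one_hot = np.zeros((len(labels), depth), dtype=np.float32)
    one_hot[np.arange(len(labels)), labels] = 1.0
    return one_hot

sample = [0, 2, 2]
per_label = to_one_hot_sketch(sample, depth=4)   # one one-hot row per label
multi_hot = np.sum(per_label, axis=0)            # single_vector=True behaviour
# per_label -> [[1,0,0,0],[0,0,1,0],[0,0,1,0]];  multi_hot -> [1., 0., 2., 0.]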
Example No. 4
 def __call__(self, tokens_batch, **kwargs):
     cap_batch = []
     max_batch_len = 0
     for utterance in tokens_batch:
         cap_list = []
         max_batch_len = max(max_batch_len, len(utterance))
         for token in utterance:
             cap = np.zeros(4, np.float32)
             # Check the case and produce corresponding one-hot
             if len(token) > 0:
                 if token[0].islower():
                     cap[0] = 1
                 elif len(token) == 1 and token[0].isupper():
                     cap[1] = 1
                 elif len(token) > 1 and token[0].isupper() and any(
                         ch.islower() for ch in token):
                     cap[2] = 1
                 elif all(ch.isupper() for ch in token):
                     cap[3] = 1
             cap_list.append(cap)
         cap_batch.append(cap_list)
     if self.pad_zeros:
         return zero_pad(cap_batch)
     else:
         return cap_batch
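For reference, a reduced single-token version of the case check above, run over a toy set of tokens (a usage sketch, not library output):

import numpy as np

def capitalization_sketch(token):
    """Mirror the four-way case pattern from the loop above for one token."""
    cap = np.zeros(4, np.float32)
    if token and token[0].islower():
        cap[0] = 1          # starts lowercase, e.g. 'hello'
    elif len(token) == 1 and token.isupper():
        cap[1] = 1          # single uppercase letter, e.g. 'I'
    elif len(token) > 1 and token[0].isupper() and any(ch.islower() for ch in token):
        cap[2] = 1          # capitalized word, e.g. 'John'
    elif token and all(ch.isupper() for ch in token):
        cap[3] = 1          # all caps, e.g. 'NASA'
    return cap              # digits and punctuation stay all zeros

for tok in ["hello", "I", "John", "NASA", "42"]:
    print(tok, capitalization_sketch(tok))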
Example No. 5
 def __call__(self, batch, **kwargs):
     indices_batch = []
     for sample in batch:
         indices_batch.append([self[token] for token in sample])
     if self._pad_with_zeros and self.is_str_batch(batch):
         indices_batch = zero_pad(indices_batch)
     return indices_batch
Example No. 6
 def __call__(self, batch, **kwargs):
     indices_batch = []
     for sample in batch:
         indices_batch.append([self[token] for token in sample])
     if self._pad_with_zeros and self.is_str_batch(batch):
         indices_batch = zero_pad(indices_batch)
     return indices_batch
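Examples 5 and 6 are identical: each token is looked up in a vocabulary and the resulting ragged index lists are zero-padded. A toy sketch of that lookup (the dict below is a hypothetical vocabulary, not the library's fitted one):

import numpy as np

token2index = {"<UNK>": 0, "the": 1, "cat": 2, "sat": 3}   # hypothetical vocabulary

def lookup_sketch(batch, pad_with_zeros=True):
    """Map every token to its index, then pad the ragged result to a rectangle."""
    indices_batch = [[token2index.get(tok, 0) for tok in sample] for sample in batch]
    if pad_with_zeros:
        max_len = max(len(s) for s in indices_batch)
        indices_batch = np.array([s + [0] * (max_len - len(s)) for s in indices_batch])
    return indices_batch

print(lookup_sketch([["the", "cat"], ["the", "cat", "sat"]]))
# [[1 2 0]
#  [1 2 3]]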
Example No. 7
    def __call__(self, batch: List[List[int]],
                 **kwargs) -> Union[List[List[np.ndarray]], List[np.ndarray]]:
        """
        Convert given batch of list of labels to one-hot representation of the batch.

        Args:
            batch: list of samples, where each sample is a list of integer labels.
            **kwargs: additional arguments

        Returns:
            if ``single_vector``, list of one-hot representations of each sample,
            otherwise, list of lists of one-hot representations of each label in a sample
        """
        one_hotted_batch = []

        for utt in batch:
            if isinstance(utt, list):
                one_hotted_utt = self._to_one_hot(utt, self._depth)
            elif isinstance(utt, int):
                if self._pad_zeros or self.single_vector:
                    one_hotted_utt = self._to_one_hot([utt], self._depth)
                else:
                    one_hotted_utt = self._to_one_hot([utt],
                                                      self._depth).reshape(-1)

            if self.single_vector:
                one_hotted_utt = np.sum(one_hotted_utt, axis=0)

            one_hotted_batch.append(one_hotted_utt)

        if self._pad_zeros:
            one_hotted_batch = zero_pad(one_hotted_batch)
        return one_hotted_batch
Example No. 8
    def __call__(self, batch: List[List[str]], *args,
                 **kwargs) -> Union[List[np.ndarray], np.ndarray]:
        """
        Embed sentences from a batch.

        Args:
            batch: A list of tokenized text samples.

        Returns:
            A batch of ELMo embeddings.
        """
        if len(batch) > self.mini_batch_size:
            batch_gen = chunk_generator(batch, self.mini_batch_size)
            elmo_output_values = []
            for mini_batch in batch_gen:
                mini_batch_out = self._mini_batch_fit(mini_batch, *args,
                                                      **kwargs)
                elmo_output_values.extend(mini_batch_out)
        else:
            elmo_output_values = self._mini_batch_fit(batch, *args, **kwargs)

        if self.pad_zero:
            elmo_output_values = zero_pad(elmo_output_values)

        return elmo_output_values
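A minimal sketch of the chunking used above, assuming chunk_generator simply yields consecutive slices of at most mini_batch_size samples (an assumption about the helper, kept deliberately simple):

def chunk_generator_sketch(items, chunk_size):
    """Yield consecutive chunks of at most `chunk_size` elements."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]

batch = [["a"], ["b"], ["c"], ["d"], ["e"]]
print([chunk for chunk in chunk_generator_sketch(batch, 2)])
# [[['a'], ['b']], [['c'], ['d']], [['e']]]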
Example No. 9
    def __call__(self, batch: List[List[str]], *args,
                 **kwargs) -> Union[List[np.ndarray], np.ndarray]:
        """
        Embed sentences from a batch.

        Args:
            batch: A list of tokenized text samples.

        Returns:
            A batch of ELMo embeddings.
        """
        if not batch:
            empty_vec = np.zeros(self.dim, dtype=np.float32)
            return [empty_vec] if self.mean else [[empty_vec]]

        filled_batch = []
        for batch_line in batch:
            batch_line = batch_line if batch_line else ['']
            filled_batch.append(batch_line)

        batch = filled_batch

        tokens_length = [len(batch_line) for batch_line in batch]
        tokens_length_max = max(tokens_length)
        batch = [
            batch_line + [''] * (tokens_length_max - len(batch_line))
            for batch_line in batch
        ]

        elmo_outputs = self.sess.run(self.elmo_outputs,
                                     feed_dict={
                                         self.tokens_ph: batch,
                                         self.tokens_length_ph: tokens_length,
                                     })

        if self.mean:
            batch = elmo_outputs['default']

            dim0, dim1 = batch.shape

            if self.dim != dim1:
                batch = np.resize(batch, (dim0, self.dim))
        else:
            batch = elmo_outputs['elmo']

            dim0, dim1, dim2 = batch.shape

            if self.dim != dim2:
                batch = np.resize(batch, (dim0, dim1, self.dim))

            batch = [
                batch_line[:length_line]
                for length_line, batch_line in zip(tokens_length, batch)
            ]

            if self.pad_zero:
                batch = zero_pad(batch)

        return batch
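The example above pads every sentence with empty strings up to the batch maximum before the session run, then trims each output row back to its true token count. A small stand-alone sketch of that round trip (the strings below stand in for the ELMo output vectors):

batch = [["the", "cat"], ["a", "dog", "barks"]]

tokens_length = [len(line) for line in batch]          # [2, 3]
max_len = max(tokens_length)
padded = [line + [''] * (max_len - len(line)) for line in batch]
# [['the', 'cat', ''], ['a', 'dog', 'barks']]

# After the model returns per-token vectors of shape (batch, max_len, dim),
# each row is cut back to its real token count:
outputs = [[f"vec({tok})" for tok in line] for line in padded]   # stand-in for elmo outputs
trimmed = [row[:length] for length, row in zip(tokens_length, outputs)]
# [['vec(the)', 'vec(cat)'], ['vec(a)', 'vec(dog)', 'vec(barks)']]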
Example No. 10
 def __call__(self, batch, mean=False, *args, **kwargs):
     """
     Embed data
     """
     batch = [self._encode(sample, mean) for sample in batch]
     if self.pad_zero:
         batch = zero_pad(batch)
     return batch
Example No. 11
 def __call__(self, batch, **kwargs):
     one_hotted_batch = []
     for utt in batch:
         one_hotted_utt = self._to_one_hot(utt, self._depth)
         one_hotted_batch.append(one_hotted_utt)
     if self._pad_zeros:
         one_hotted_batch = zero_pad(one_hotted_batch)
     return one_hotted_batch
Example No. 12
 def _build_feed_dict(self, input_ids, input_masks, y_masks, 
                      y_head=None, y_dep=None, y_tag=None) -> dict:
     y_masks = np.concatenate([np.ones_like(y_masks[:,:1]), y_masks[:, 1:]], axis=1)
     feed_dict = self._build_basic_feed_dict(input_ids, input_masks, train=(y_head is not None))
     feed_dict[self.y_masks_ph] = y_masks
     if y_head is not None:
         y_head = zero_pad(y_head)
         y_head = np.concatenate([np.zeros_like(y_head[:,:1]), y_head], axis=1)
         y_dep = zero_pad(y_dep)
         y_dep = np.concatenate([np.zeros_like(y_dep[:,:1]), y_dep], axis=1)
         feed_dict.update({self.embeddings_keep_prob_ph: 1.0 - self.embeddings_dropout,
                           self.y_head_ph: y_head,
                           self.y_dep_ph: y_dep})
         if self.predict_tags:
             y_tag = np.concatenate([np.zeros_like(y_tag[:,:1]), y_tag], axis=1)
             feed_dict.update({self.y_tag_ph: y_tag, self.tag_weight_ph: self.tag_weight})
     return feed_dict
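The concatenations above prepend one extra position, corresponding to the artificial root/[CLS] token, to every zero-padded target matrix. A small numpy sketch of that step:

import numpy as np

# Two sentences with head indices of length 3 and 2, already zero-padded to shape (2, 3).
y_head = np.array([[2, 0, 2],
                   [2, 0, 0]])

# Prepend a column of zeros so that column 0 corresponds to the artificial root position.
y_head_with_root = np.concatenate([np.zeros_like(y_head[:, :1]), y_head], axis=1)
print(y_head_with_root)
# [[0 2 0 2]
#  [0 2 0 0]]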
Example No. 13
    def __call__(self, batch, is_top=True, **kwargs):
        if isinstance(batch, Iterable) and not isinstance(batch, str):
            looked_up_batch = [self(sample, is_top=False) for sample in batch]
        else:
            return self[batch]
        if is_top and self._pad_with_zeros and not is_str_batch(looked_up_batch):
            looked_up_batch = zero_pad(looked_up_batch)

        return looked_up_batch
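A sketch of the recursion in this example: nested iterables are mapped element-wise, and only a plain string reaches the vocabulary lookup. The dict-backed class below is a stand-in for illustration, not the library's vocabulary component:

from collections.abc import Iterable

class LookupSketch:
    def __init__(self, mapping):
        self.mapping = mapping

    def __getitem__(self, token):
        return self.mapping.get(token, 0)

    def __call__(self, batch, is_top=True):
        # Recurse through lists of lists; only plain strings hit the lookup.
        if isinstance(batch, Iterable) and not isinstance(batch, str):
            return [self(sample, is_top=False) for sample in batch]
        return self[batch]

vocab = LookupSketch({"the": 1, "cat": 2})
print(vocab([["the", "cat"], ["cat"]]))   # [[1, 2], [2]]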
Example No. 14
    def __call__(self, batch: List[List[str]],
                 *args, **kwargs) -> Union[List[np.ndarray], np.ndarray]:
        """
        Embed sentences from a batch.

        Args:
            batch: A list of tokenized text samples.

        Returns:
            A batch of ELMo embeddings.
        """
        if not batch:
            empty_vec = np.zeros(self.dim, dtype=np.float32)
            return [empty_vec] if self.mean else [[empty_vec]]

        filled_batch = []
        for batch_line in batch:
            batch_line = batch_line if batch_line else ['']
            filled_batch.append(batch_line)

        batch = filled_batch

        tokens_length = [len(batch_line) for batch_line in batch]
        tokens_length_max = max(tokens_length)
        batch = [batch_line + ['']*(tokens_length_max - len(batch_line)) for batch_line in batch]

        elmo_outputs = self.sess.run(self.elmo_outputs,
                                     feed_dict={
                                         self.tokens_ph: batch,
                                         self.tokens_length_ph: tokens_length,
                                     })

        if self.mean:
            batch = elmo_outputs['default']

            dim0, dim1 = batch.shape

            if self.dim != dim1:
                batch = np.resize(batch, (dim0, self.dim))
        else:
            batch = elmo_outputs['elmo']

            dim0, dim1, dim2 = batch.shape

            if self.dim != dim2:
                batch = np.resize(batch, (dim0, dim1, self.dim))

            batch = [batch_line[:length_line] for length_line, batch_line in zip(tokens_length, batch)]

            if self.pad_zero:
                batch = zero_pad(batch)

        return batch
Example No. 15
    def __call__(self, batch, is_top=True, **kwargs):
        if isinstance(batch, Iterable) and not isinstance(batch, str):
            looked_up_batch = [self(sample, is_top=False) for sample in batch]
        else:
            return self[batch]
        if is_top and self._pad_with_zeros and not is_str_batch(
                looked_up_batch):
            looked_up_batch = zero_pad(looked_up_batch)

        return looked_up_batch
Example No. 16
 def __call__(self, batch, mean=False, *args, **kwargs):
     """
     Embed data
     """
     embedded = []
     for n, sample in enumerate(batch):
         embedded.append(self._encode(sample, mean))
     if self.pad_zero:
         embedded = zero_pad(embedded)
     return embedded
Example No. 17
    def _mini_batch_fit(self, batch: List[List[str]], *args,
                        **kwargs) -> Union[List[np.ndarray], np.ndarray]:
        """
        Embed sentences from a batch.

        Args:
            batch: A list of tokenized text samples.

        Returns:
            A batch of ELMo embeddings.
        """
        batch, tokens_length = self._fill_batch(batch)

        elmo_outputs = self.sess.run(self.elmo_outputs,
                                     feed_dict={
                                         self.tokens_ph: batch,
                                         self.tokens_length_ph: tokens_length
                                     })

        if 'default' in self.elmo_output_names:
            elmo_output_values = elmo_outputs['default']
            dim0, dim1 = elmo_output_values.shape
            if self.dim != dim1:
                shape = (dim0, self.dim
                         if isinstance(self.dim, int) else self.dim[0])
                elmo_output_values = np.resize(elmo_output_values, shape)
        else:
            elmo_output_values = [
                elmo_outputs[elmo_output_name]
                for elmo_output_name in self.elmo_output_names
            ]
            elmo_output_values = np.concatenate(elmo_output_values, axis=-1)

            dim0, dim1, dim2 = elmo_output_values.shape
            if self.concat_last_axis and self.dim != dim2:
                shape = (dim0, dim1, self.dim)
                elmo_output_values = np.resize(elmo_output_values, shape)

            elmo_output_values = [
                elmo_output_values_line[:length_line]
                for length_line, elmo_output_values_line in zip(
                    tokens_length, elmo_output_values)
            ]

            if self.pad_zero:
                elmo_output_values = zero_pad(elmo_output_values)

            if not self.concat_last_axis:
                slice_indexes = np.cumsum(self.dim).tolist()[:-1]
                elmo_output_values = [[
                    np.array_split(vec, slice_indexes) for vec in tokens
                ] for tokens in elmo_output_values]

        return elmo_output_values
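When concat_last_axis is disabled, the example above splits every concatenated token vector back into its per-output pieces with np.cumsum and np.array_split. A worked sketch of that step, assuming self.dim is a list of per-output widths:

import numpy as np

dims = [2, 3]                                   # e.g. two ELMo outputs of width 2 and 3
vec = np.arange(5, dtype=np.float32)            # one token's concatenated vector

slice_indexes = np.cumsum(dims).tolist()[:-1]   # [2]
pieces = np.array_split(vec, slice_indexes)
print(pieces)                                   # [array([0., 1.]), array([2., 3., 4.])]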
Example No. 18
    def __call__(self,
                 tokens: List[List[str]],
                 tags: List[List[str]] = None,
                 **kwargs):
        subword_tokens, subword_tok_ids, subword_masks, subword_tags = [], [], [], []
        for i in range(len(tokens)):
            toks = tokens[i]
            ys = ['X'] * len(toks) if tags is None else tags[i]
            assert len(toks) == len(ys), \
                f"toks({len(toks)}) should have the same length as "\
                f" ys({len(ys)}), tokens = {toks}."
            sw_toks, sw_mask, sw_ys = self._ner_bert_tokenize(
                toks, [1] * len(toks), ys, self.tokenizer,
                self.max_subword_length)
            if self.max_seq_length is not None:
                sw_toks = sw_toks[:self.max_seq_length]
                sw_mask = sw_mask[:self.max_seq_length]
                sw_ys = sw_ys[:self.max_seq_length]

                # add [sep] if we cut it
                if sw_toks[-1] != '[SEP]':
                    sw_toks[-1] = '[SEP]'
                    sw_mask[-1] = 0
                    sw_ys[-1] = 'X'
            subword_tokens.append(sw_toks)
            subword_tok_ids.append(
                self.tokenizer.convert_tokens_to_ids(sw_toks))
            subword_masks.append(sw_mask)
            subword_tags.append(sw_ys)
            assert len(sw_mask) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys),\
                f"length of mask({len(sw_mask)}), tokens({len(sw_toks)}),"\
                f" token ids({len(subword_tok_ids[-1])}) and ys({len(ys)})"\
                f" for tokens = `{toks}` should match"
        subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
        subword_masks = zero_pad(subword_masks, dtype=int, padding=0)
        if tags is not None:
            return subword_tokens, subword_tok_ids, subword_masks, subword_tags
        return subword_tokens, subword_tok_ids, subword_masks
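A tiny worked sketch of the truncation step in this example: after cutting to max_seq_length, the final position is forced back to '[SEP]' with a masked-out 'X' tag:

def truncate_with_sep_sketch(sw_toks, sw_mask, sw_ys, max_seq_length):
    """Cut subword sequences to max_seq_length and make sure they still end with [SEP]."""
    sw_toks = sw_toks[:max_seq_length]
    sw_mask = sw_mask[:max_seq_length]
    sw_ys = sw_ys[:max_seq_length]
    if sw_toks[-1] != '[SEP]':
        sw_toks[-1] = '[SEP]'
        sw_mask[-1] = 0
        sw_ys[-1] = 'X'
    return sw_toks, sw_mask, sw_ys

toks = ['[CLS]', 'new', 'york', 'city', '[SEP]']
mask = [0, 1, 1, 1, 0]
ys   = ['X', 'B-LOC', 'I-LOC', 'I-LOC', 'X']
print(truncate_with_sep_sketch(toks, mask, ys, max_seq_length=4))
# (['[CLS]', 'new', 'york', '[SEP]'], [0, 1, 1, 0], ['X', 'B-LOC', 'I-LOC', 'X'])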
Example No. 19
    def __call__(self, batch: List[List[str]], mean: bool = None) -> List[Union[list, np.ndarray]]:
        """
        Embed sentences from batch

        Args:
            batch: list of tokenized text samples
            mean: whether to return mean embedding of tokens per sample

        Returns:
            embedded batch
        """
        batch = [self._encode(sample, mean) for sample in batch]
        if self.pad_zero:
            batch = zero_pad(batch)
        return batch
Example No. 20
    def __call__(self, batch: List[List[str]], mean: bool = None) -> List[Union[list, np.ndarray]]:
        """
        Embed sentences from batch

        Args:
            batch: list of tokenized text samples
            mean: whether to return mean embedding of tokens per sample

        Returns:
            embedded batch
        """
        batch = [self._encode(sample, mean) for sample in batch]
        if self.pad_zero:
            batch = zero_pad(batch)
        return batch
Example No. 21
    def __call__(self, batch, mean=False, *args, **kwargs):
        """
        Embed sentences from batch
        Args:
            batch: list of tokenized text samples
            mean: whether to return mean embedding of tokens per sample
            *args: arguments
            **kwargs: arguments

        Returns:
            embedded batch
        """
        batch = [self._encode(sample, mean) for sample in batch]
        if self.pad_zero:
            batch = zero_pad(batch)
        return batch
Example No. 22
    def __call__(self, batch: List[List[str]], mean: bool = False, *args, **kwargs) -> List[Union[list, np.ndarray]]:
        """
        Embed sentences from batch

        Args:
            batch: list of tokenized text samples
            mean: whether to return mean embedding of tokens per sample
            *args: arguments
            **kwargs: arguments

        Returns:
            embedded batch
        """
        embedded = []
        for n, sample in enumerate(batch):
            embedded.append(self._encode(sample, mean))
        if self.pad_zero:
            embedded = zero_pad(embedded)
        return embedded
Example No. 23
    def __call__(self,
                 batch: List[List[str]],
                 mean: bool = False,
                 *args,
                 **kwargs) -> List[Union[list, np.ndarray]]:
        """
        Embed sentences from batch

        Args:
            batch: list of tokenized text samples
            mean: whether to return mean embedding of tokens per sample
            *args: arguments
            **kwargs: arguments

        Returns:
            embedded batch
        """
        embedded = []
        for n, sample in enumerate(batch):
            embedded.append(self._encode(sample, mean))
        if self.pad_zero:
            embedded = zero_pad(embedded)
        return embedded
Example No. 24
 def __call__(self, tokens_batch, **kwargs):
     cap_batch = []
     max_batch_len = 0
     for utterance in tokens_batch:
         cap_list = []
         max_batch_len = max(max_batch_len, len(utterance))
         for token in utterance:
             cap = np.zeros(4, np.float32)
             # Check the case and produce corresponding one-hot
             if len(token) > 0:
                 if token[0].islower():
                     cap[0] = 1
                 elif len(token) == 1 and token[0].isupper():
                     cap[1] = 1
                 elif len(token) > 1 and token[0].isupper() and any(ch.islower() for ch in token):
                     cap[2] = 1
                 elif all(ch.isupper() for ch in token):
                     cap[3] = 1
             cap_list.append(cap)
         cap_batch.append(cap_list)
     if self.pad_zeros:
         return zero_pad(cap_batch)
     else:
         return cap_batch
Example No. 25
    def __call__(self, batch: List[List[str]],
                 *args, **kwargs) -> Union[List[np.ndarray], np.ndarray]:
        """
        Embed sentences from a batch.

        Args:
            batch: A list of tokenized text samples.

        Returns:
            A batch of ELMo embeddings.
        """
        if len(batch) > self.mini_batch_size:
            batch_gen = chunk_generator(batch, self.mini_batch_size)
            elmo_output_values = []
            for mini_batch in batch_gen:
                mini_batch_out = self._mini_batch_fit(mini_batch, *args, **kwargs)
                elmo_output_values.extend(mini_batch_out)
        else:
            elmo_output_values = self._mini_batch_fit(batch, *args, **kwargs)

        if self.pad_zero:
            elmo_output_values = zero_pad(elmo_output_values)

        return elmo_output_values
    def __call__(self,
                 tokens: Union[List[List[str]], List[str]],
                 tags: List[List[str]] = None,
                 **kwargs):
        tokens_offsets_batch = [[] for _ in tokens]
        if isinstance(tokens[0], str):
            tokens_batch = []
            tokens_offsets_batch = []
            for s in tokens:
                tokens_list = []
                tokens_offsets_list = []
                for elem in re.finditer(self._re_tokenizer, s):
                    tokens_list.append(elem[0])
                    tokens_offsets_list.append((elem.start(), elem.end()))
                tokens_batch.append(tokens_list)
                tokens_offsets_batch.append(tokens_offsets_list)
            tokens = tokens_batch
        subword_tokens, subword_tok_ids, startofword_markers, subword_tags = [], [], [], []
        for i in range(len(tokens)):
            toks = tokens[i]
            ys = ["O"] * len(toks) if tags is None else tags[i]
            assert len(toks) == len(
                ys
            ), f"toks({len(toks)}) should have the same length as ys({len(ys)})"
            sw_toks, sw_marker, sw_ys = self._ner_bert_tokenize(
                toks,
                ys,
                self.tokenizer,
                self.max_subword_length,
                mode=self.mode,
                subword_mask_mode=self.subword_mask_mode,
                token_masking_prob=self.token_masking_prob,
            )
            if self.max_seq_length is not None:
                if len(sw_toks) > self.max_seq_length:
                    raise RuntimeError(
                        f"input sequence after bert tokenization"
                        f" shouldn't exceed {self.max_seq_length} tokens.")
            subword_tokens.append(sw_toks)
            subword_tok_ids.append(
                self.tokenizer.convert_tokens_to_ids(sw_toks))
            startofword_markers.append(sw_marker)
            subword_tags.append(sw_ys)
            assert len(sw_marker) == len(sw_toks) == len(
                subword_tok_ids[-1]
            ) == len(sw_ys), (
                f"length of sow_marker({len(sw_marker)}), tokens({len(sw_toks)}),"
                f" token ids({len(subword_tok_ids[-1])}) and ys({len(ys)})"
                f" for tokens = `{toks}` should match")

        subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
        startofword_markers = zero_pad(startofword_markers,
                                       dtype=int,
                                       padding=0)
        attention_mask = Mask()(subword_tokens)

        if tags is not None:
            if self.provide_subword_tags:
                return tokens, subword_tokens, subword_tok_ids, attention_mask, startofword_markers, subword_tags
            else:
                nonmasked_tags = [[t for t in ts if t != "X"] for ts in tags]
                for swts, swids, swms, ts in zip(subword_tokens,
                                                 subword_tok_ids,
                                                 startofword_markers,
                                                 nonmasked_tags):
                    if (len(swids) != len(swms)) or (len(ts) != sum(swms)):
                        log.warning(
                            "Not matching lengths of the tokenization!")
                        log.warning(
                            f"Tokens len: {len(swts)}\n Tokens: {swts}")
                        log.warning(
                            f"Markers len: {len(swms)}, sum: {sum(swms)}")
                        log.warning(f"Masks: {swms}")
                        log.warning(f"Tags len: {len(ts)}\n Tags: {ts}")
                return tokens, subword_tokens, subword_tok_ids, attention_mask, startofword_markers, nonmasked_tags
        return tokens, subword_tokens, subword_tok_ids, startofword_markers, attention_mask, tokens_offsets_batch
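The offset-aware tokenization at the top of this preprocessor can be reproduced with re.finditer; the pattern below is a guess for illustration only, since the class keeps its own _re_tokenizer:

import re

re_tokenizer = re.compile(r"[\w']+|[^\w ]")    # hypothetical pattern, for illustration only

text = "Barack Obama was born in Hawaii."
tokens, offsets = [], []
for match in re.finditer(re_tokenizer, text):
    tokens.append(match[0])
    offsets.append((match.start(), match.end()))

print(tokens)    # ['Barack', 'Obama', 'was', 'born', 'in', 'Hawaii', '.']
print(offsets)   # [(0, 6), (7, 12), (13, 16), (17, 21), (22, 24), (25, 31), (31, 32)]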
    def __call__(self,
                 tokens_batch,
                 entity_offsets_batch,
                 mentions_batch=None,
                 pages_batch=None):
        token_ids_batch, attention_mask_batch, subw_tokens_batch, entity_subw_indices_batch = [], [], [], []
        if mentions_batch is None:
            mentions_batch = [[] for _ in tokens_batch]
        if pages_batch is None:
            pages_batch = [[] for _ in tokens_batch]

        for tokens, entity_offsets_list, mentions_list, pages_list in zip(
                tokens_batch, entity_offsets_batch, mentions_batch,
                pages_batch):
            tokens_list = []
            tokens_offsets_list = []
            for elem in re.finditer(self._re_tokenizer, tokens):
                tokens_list.append(elem[0])
                tokens_offsets_list.append((elem.start(), elem.end()))

            entity_indices_list = []
            for start_offset, end_offset in entity_offsets_list:
                entity_indices = []
                for ind, (start_tok_offset,
                          end_tok_offset) in enumerate(tokens_offsets_list):
                    if start_tok_offset >= start_offset and end_tok_offset <= end_offset:
                        entity_indices.append(ind)
                if not entity_indices:
                    for ind, (
                            start_tok_offset,
                            end_tok_offset) in enumerate(tokens_offsets_list):
                        if start_tok_offset >= start_offset:
                            entity_indices.append(ind)
                            break
                entity_indices_list.append(set(entity_indices))

            ind = 0
            subw_tokens_list = ["[CLS]"]
            entity_subw_indices_list = [[] for _ in entity_indices_list]
            for n, tok in enumerate(tokens_list):
                subw_tok = self.tokenizer.tokenize(tok)
                subw_tokens_list += subw_tok
                for j in range(len(entity_indices_list)):
                    if n in entity_indices_list[j]:
                        for k in range(len(subw_tok)):
                            entity_subw_indices_list[j].append(ind + k + 1)
                ind += len(subw_tok)
            subw_tokens_list.append("[SEP]")
            subw_tokens_batch.append(subw_tokens_list)

            for n in range(len(entity_subw_indices_list)):
                entity_subw_indices_list[n] = sorted(
                    entity_subw_indices_list[n])
            entity_subw_indices_batch.append(entity_subw_indices_list)

        token_ids_batch = [
            self.tokenizer.convert_tokens_to_ids(subw_tokens_list)
            for subw_tokens_list in subw_tokens_batch
        ]
        token_ids_batch = zero_pad(token_ids_batch, dtype=int, padding=0)
        attention_mask_batch = Mask()(subw_tokens_batch)

        return token_ids_batch, attention_mask_batch, entity_subw_indices_batch
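A reduced sketch of the character-level part of this mapping: each entity's character span is matched against the token offsets to collect the token indices it covers (the subword expansion and the +1 shift for '[CLS]' are omitted):

def entity_token_indices_sketch(tokens_offsets, entity_offsets):
    """Return, per entity span, the indices of tokens that fall inside it."""
    indices_per_entity = []
    for start, end in entity_offsets:
        indices = [i for i, (tok_start, tok_end) in enumerate(tokens_offsets)
                   if tok_start >= start and tok_end <= end]
        indices_per_entity.append(indices)
    return indices_per_entity

# Tokens: 'Barack'(0,6) 'Obama'(7,12) 'visited'(13,20) 'Paris'(21,26)
tokens_offsets = [(0, 6), (7, 12), (13, 20), (21, 26)]
print(entity_token_indices_sketch(tokens_offsets, [(0, 12), (21, 26)]))
# [[0, 1], [3]]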