import re
from logging import getLogger
from typing import List, Union

# Import paths below assume DeepPavlov's package layout.
from deeppavlov.core.data.utils import zero_pad
from deeppavlov.models.preprocessors.mask import Mask

log = getLogger(__name__)


def __call__(self,
             tokens: Union[List[List[str]], List[str]],
             tags: List[List[str]] = None,
             **kwargs):
    # Accept either pre-tokenized samples or raw strings; raw strings are
    # split with the regex tokenizer.
    if isinstance(tokens[0], str):
        tokens = [re.findall(self._re_tokenizer, s) for s in tokens]
    subword_tokens, subword_tok_ids, startofword_markers, subword_tags = [], [], [], []
    for i in range(len(tokens)):
        toks = tokens[i]
        # Without gold tags, fall back to all-'O' placeholders.
        ys = ['O'] * len(toks) if tags is None else tags[i]
        assert len(toks) == len(ys), \
            f"toks({len(toks)}) should have the same length as ys({len(ys)})"
        sw_toks, sw_marker, sw_ys = self._ner_bert_tokenize(
            toks,
            ys,
            self.tokenizer,
            self.max_subword_length,
            mode=self.mode,
            subword_mask_mode=self.subword_mask_mode,
            token_masking_prob=self.token_masking_prob)
        if self.max_seq_length is not None:
            if len(sw_toks) > self.max_seq_length:
                raise RuntimeError(f"input sequence after BERT tokenization"
                                   f" shouldn't exceed {self.max_seq_length} tokens.")
        subword_tokens.append(sw_toks)
        subword_tok_ids.append(self.tokenizer.convert_tokens_to_ids(sw_toks))
        startofword_markers.append(sw_marker)
        subword_tags.append(sw_ys)
        assert len(sw_marker) == len(sw_toks) == len(subword_tok_ids[-1]) == len(sw_ys), \
            f"length of sw_marker({len(sw_marker)}), tokens({len(sw_toks)})," \
            f" token ids({len(subword_tok_ids[-1])}) and sw_ys({len(sw_ys)})" \
            f" for tokens = `{toks}` should match"

    # Pad every sample in the batch to a common length.
    subword_tok_ids = zero_pad(subword_tok_ids, dtype=int, padding=0)
    startofword_markers = zero_pad(startofword_markers, dtype=int, padding=0)
    attention_mask = Mask()(subword_tokens)

    if tags is not None:
        if self.provide_subword_tags:
            return tokens, subword_tokens, subword_tok_ids, \
                   attention_mask, startofword_markers, subword_tags
        # Drop the 'X' padding tags so tag lists align with the
        # start-of-word markers.
        nonmasked_tags = [[t for t in ts if t != 'X'] for ts in tags]
        for swts, swids, swms, ts in zip(subword_tokens,
                                         subword_tok_ids,
                                         startofword_markers,
                                         nonmasked_tags):
            if (len(swids) != len(swms)) or (len(ts) != sum(swms)):
                log.warning('Not matching lengths of the tokenization!')
                log.warning(f'Tokens len: {len(swts)}\n Tokens: {swts}')
                log.warning(f'Markers len: {len(swms)}, sum: {sum(swms)}')
                log.warning(f'Masks: {swms}')
                log.warning(f'Tags len: {len(ts)}\n Tags: {ts}')
        return tokens, subword_tokens, subword_tok_ids, \
               attention_mask, startofword_markers, nonmasked_tags
    return tokens, subword_tokens, subword_tok_ids, startofword_markers, attention_mask
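
# A minimal usage sketch (hedged: assumes this method lives on a DeepPavlov
# NER preprocessor such as TorchTransformersNerPreprocessor, built around a
# HuggingFace tokenizer; the constructor arguments shown are illustrative):
#
#     preprocessor = TorchTransformersNerPreprocessor(
#         vocab_file="bert-base-cased", max_seq_length=512, max_subword_length=15)
#     tokens, subword_tokens, subword_tok_ids, startofword_markers, attention_mask = \
#         preprocessor(["John lives in New York"])
#
# subword_tok_ids comes back zero-padded to the longest sequence in the batch;
# startofword_markers holds 1 at the first subword of each original token, so a
# downstream NER head can project subword-level logits back onto whole tokens.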
def __call__(self, tokens_batch, entity_offsets_batch, mentions_batch=None, pages_batch=None):
    subw_tokens_batch, entity_subw_indices_batch = [], []
    if mentions_batch is None:
        mentions_batch = [[] for _ in tokens_batch]
    if pages_batch is None:
        pages_batch = [[] for _ in tokens_batch]
    for tokens, entity_offsets_list, mentions_list, pages_list in zip(
            tokens_batch, entity_offsets_batch, mentions_batch, pages_batch):
        # Tokenize the raw text, keeping each token's character offsets.
        tokens_list = []
        tokens_offsets_list = []
        for elem in re.finditer(self._re_tokenizer, tokens):
            tokens_list.append(elem[0])
            tokens_offsets_list.append((elem.start(), elem.end()))
        # Map each entity's character span onto the indices of the tokens
        # it fully covers.
        entity_indices_list = []
        for start_offset, end_offset in entity_offsets_list:
            entity_indices = []
            for ind, (start_tok_offset, end_tok_offset) in enumerate(tokens_offsets_list):
                if start_tok_offset >= start_offset and end_tok_offset <= end_offset:
                    entity_indices.append(ind)
            if not entity_indices:
                # The span covers no whole token: fall back to the first
                # token starting at or after the span start.
                for ind, (start_tok_offset, end_tok_offset) in enumerate(tokens_offsets_list):
                    if start_tok_offset >= start_offset:
                        entity_indices.append(ind)
                        break
            entity_indices_list.append(set(entity_indices))
        # Re-tokenize into subwords and translate token indices into subword
        # indices, shifted by one for the leading "[CLS]".
        ind = 0
        subw_tokens_list = ["[CLS]"]
        entity_subw_indices_list = [[] for _ in entity_indices_list]
        for n, tok in enumerate(tokens_list):
            subw_tok = self.tokenizer.tokenize(tok)
            subw_tokens_list += subw_tok
            for j in range(len(entity_indices_list)):
                if n in entity_indices_list[j]:
                    for k in range(len(subw_tok)):
                        entity_subw_indices_list[j].append(ind + k + 1)
            ind += len(subw_tok)
        subw_tokens_list.append("[SEP]")
        subw_tokens_batch.append(subw_tokens_list)
        for n in range(len(entity_subw_indices_list)):
            entity_subw_indices_list[n] = sorted(entity_subw_indices_list[n])
        entity_subw_indices_batch.append(entity_subw_indices_list)

    token_ids_batch = [self.tokenizer.convert_tokens_to_ids(subw_tokens_list)
                       for subw_tokens_list in subw_tokens_batch]
    token_ids_batch = zero_pad(token_ids_batch, dtype=int, padding=0)
    attention_mask_batch = Mask()(subw_tokens_batch)
    return token_ids_batch, attention_mask_batch, entity_subw_indices_batch
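
# A minimal usage sketch (hedged: assumes an entity-linking preprocessor along
# the lines of DeepPavlov's TorchTransformersEntityRankerPreprocessor; the
# class name and arguments are illustrative, not confirmed by this excerpt):
#
#     text = "Barack Obama was born in Hawaii"
#     token_ids, attention_mask, entity_subw_indices = preprocessor(
#         [text], [[(0, 12)]])  # one sample, one entity span over "Barack Obama"
#
# entity_subw_indices[0][0] lists the positions of that entity's subwords
# inside the "[CLS]" ... "[SEP]" sequence (offset by 1 for "[CLS]"), which is
# what lets an entity ranker pool the transformer states of just the mention.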