import bisect
import logging
from typing import Optional, Set, Text

import sentencepiece as spm

# `Features`, `AnswerSpan`, and `evaluation` are project-local helpers
# (the tokenized-example container and the answer normalizer).


def realign_answer_span(features: Features,
                        answer_set: Optional[Set[Text]],
                        processor: spm.SentencePieceProcessor,
                        span: AnswerSpan) -> Optional[AnswerSpan]:
    """Aligns an answer span to the given tokenization."""
    # Find the token whose byte range contains the answer's first byte.
    i = bisect.bisect_left(features.token_offsets, span.begin)
    if i == len(features.token_offsets) or span.begin < features.token_offsets[i]:
        i -= 1
    # Advance j to the last token that starts before the answer ends.
    j = i + 1
    answer_end = span.begin + len(span.text.encode('utf-8'))
    while (j < len(features.token_offsets)
           and features.token_offsets[j] < answer_end):
        j += 1
    j -= 1
    sp_answer = (
        features.context[features.token_offsets[i]:features.token_offsets[j + 1]]
        if j + 1 < len(features.token_offsets)
        else features.context[features.token_offsets[i]:])
    # A leading '▁' piece stands for the preceding space; drop that byte.
    if (processor.IdToPiece(features.token_ids[i]).startswith('▁')
            and features.token_offsets[i] > 0):
        sp_answer = sp_answer[1:]
    sp_answer = evaluation.normalize_answer(sp_answer.decode('utf-8'))
    if answer_set is not None and sp_answer not in answer_set:
        # No need to warn if the cause was breaking word boundaries.
        if sp_answer and len(sp_answer) <= len(
                evaluation.normalize_answer(span.text)):
            logging.warning('%s: "%s" not in %s.', features.question_id,
                            sp_answer, answer_set)
        return None
    return AnswerSpan(begin=i, end=j, text=span.text)
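# A minimal, self-contained sketch of the bisect step above; the offsets and
# the answer position are made-up values for illustration only.
import bisect

token_offsets = [0, 4, 11, 15, 22]  # byte offset where each piece begins
answer_begin = 13                   # byte offset of the answer's first byte

i = bisect.bisect_left(token_offsets, answer_begin)
# bisect_left returns an insertion point; step back when the answer starts
# *inside* token i-1 rather than exactly on a token boundary.
if i == len(token_offsets) or answer_begin < token_offsets[i]:
    i -= 1
assert i == 2  # the piece starting at byte 11 contains byte 13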
from sentencepiece import SentencePieceProcessor

# `preprocess_text` and `encode_pieces` are assumed to be the XLNet-style
# preprocessing helpers (e.g. prepro_utils) that normalize the text and run
# deterministic or sampled segmentation.


class SentencePieceTokenizer:

    def __init__(self, spm_file, do_lower_case=True):
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        return encode_pieces(self.processor, text, sample=False)

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        return [self.processor.IdToPiece(_id) for _id in ids]
import os

from sentencepiece import SentencePieceProcessor


class SentencePieceTokenizer:
    """Variant of the tokenizer above that validates the model path."""

    def __init__(self, spm_file, do_lower_case=True):
        if not os.path.exists(spm_file):
            raise ValueError(
                "Can't find spm_file \"%s\". "
                "Please pass the correct path of the sentence-piece model file, "
                "e.g. `spiece.model`." % spm_file
            )
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        return encode_pieces(self.processor, text, sample=False)

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        return [self.processor.IdToPiece(_id) for _id in ids]
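# Hypothetical round trip with either tokenizer variant, assuming a trained
# model at ./spiece.model and the helpers noted above; the exact pieces
# depend entirely on that model's vocabulary.
tokenizer = SentencePieceTokenizer("spiece.model", do_lower_case=True)

pieces = tokenizer.tokenize("Hello, world!")  # e.g. ['▁hello', ',', '▁world', '!']
ids = tokenizer.convert_tokens_to_ids(pieces)
# Ids map back to the same pieces as long as none of them were unknown.
assert tokenizer.convert_ids_to_tokens(ids) == pieces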
import allennlp.nn.beam_search
import sentencepiece as spm
import torch

# `ids_to_strs` is a project-local helper that detokenizes beam outputs.


def beam_search_decode(
    model,
    X,
    sp: spm.SentencePieceProcessor,
    max_decode_len,
    k,
    per_node_k=None,
    constrain_decoding=False,
    sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
):
    if sampler == "top_p":
        sampler = allennlp.nn.beam_search.TopPSampler(
            p=top_p_threshold, temperature=top_p_temperature)
    elif sampler == "deterministic":
        sampler = None
    else:
        raise ValueError("Unsupported sampler")

    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    pad_id = sp.PieceToId("[PAD]")
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    V_full = sp.GetPieceSize()  # size of the full vocabulary

    # Mark pieces containing characters outside the allowed alphabet.
    invalid_vocab_mask = torch.zeros(V_full, dtype=torch.bool, device=X.device)
    if constrain_decoding:
        alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_ "
        for id in range(V_full):
            piece = sp.IdToPiece(id)
            if not (id in [pad_id, bos_id, eos_id]
                    or all(c in alphabet for c in piece)):
                invalid_vocab_mask[id] = True
    V = V_full

    model.eval()

    allen_bs = allennlp.nn.beam_search.BeamSearch(
        end_index=eos_id,
        max_steps=max_decode_len,
        beam_size=k,
        per_node_beam_size=per_node_k,
        sampler=sampler,
    )

    # Encode X once; the step function reuses this memory at every step.
    start_predictions = torch.tensor([bos_id] * B, dtype=torch.long,
                                     device=X.device)
    start_state = {
        "prev_tokens": torch.zeros(B, 0, dtype=torch.long, device=X.device),
        "memory": model.encode(X).transpose(0, 1),  # [B, T, d_model]
    }

    def step(last_tokens, current_state, t):
        """Args:
            last_tokens: (group_size,)
            current_state: {}
            t: int
        """
        group_size = last_tokens.size(0)
        prev_tokens = torch.cat(
            [current_state["prev_tokens"], last_tokens.unsqueeze(1)],
            dim=-1)  # [B*k, t+1]

        all_log_probs = model.decode(current_state["memory"].transpose(0, 1),
                                     prev_tokens)
        next_log_probs = all_log_probs[:, -1, :]
        if constrain_decoding:
            next_log_probs = next_log_probs.masked_fill(
                invalid_vocab_mask, float("-inf"))
        next_log_probs = torch.nn.functional.log_softmax(next_log_probs, dim=-1)
        assert next_log_probs.shape == (group_size, V)

        return (next_log_probs, {
            "prev_tokens": prev_tokens,
            "memory": current_state["memory"],
        })

    predictions, log_probs = allen_bs.search(
        start_predictions=start_predictions, start_state=start_state, step=step)
    model.train()

    prediction = ids_to_strs(predictions, sp)
    return prediction, log_probs
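# A minimal sketch of the vocabulary-masking trick used above, with toy
# logits over a 5-piece vocabulary; all values are illustrative only.
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0, 0.0]])
invalid = torch.tensor([False, True, False, True, False])

# Filling disallowed pieces with -inf *before* log_softmax gives them exactly
# zero probability mass and renormalizes the remaining pieces.
log_probs = logits.masked_fill(invalid, float("-inf")).log_softmax(dim=-1)
assert torch.isinf(log_probs[0, 1]) and torch.isinf(log_probs[0, 3])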
import functools
import os
from typing import List

import fasttext
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor

# `EmbedderInterface` is a project-local base class; `self.pieces_slice`
# (derived from `max_pieces`) truncates the piece sequence.


class SentencepieceFasttextEmbed(EmbedderInterface):

    class Config(EmbedderInterface.Config):
        pass

    @classmethod
    def from_config(cls, config: Config):
        spm_model_file = os.path.join(config.preproc_dir, "spm.model")
        fasttext_model_file = os.path.join(config.preproc_dir,
                                           "fasttext-model.bin")
        return cls(spm_model_file, fasttext_model_file, config.max_pieces)

    def __init__(self,
                 spm_model_file: str,
                 fasttext_model_file: str = '',
                 max_pieces: int = -1):
        super().__init__(max_pieces=max_pieces)
        self.spm = SentencePieceProcessor()
        self.spm.Load(spm_model_file)

        self.pad_idx = self.spm.pad_id()
        self.pad_token = self.spm.IdToPiece(self.pad_idx)
        self.unk_idx = self.spm.unk_id()
        self.unk_token = self.spm.IdToPiece(self.unk_idx)
        self.bos_idx = self.spm.bos_id()
        self.bos_token = self.spm.IdToPiece(self.bos_idx)
        self.eos_idx = self.spm.eos_id()
        self.eos_token = self.spm.IdToPiece(self.eos_idx)

        if fasttext_model_file:
            self.fasttext = fasttext.load_model(fasttext_model_file)

    @property
    def embed_dim(self):
        return self.fasttext.dim

    @property
    def n_vocab(self):
        return self.spm.get_piece_size()

    def encode_text_as_ids(self, text: str) -> np.ndarray:
        """Doesn't produce BOS or EOS ids."""
        return np.asarray(self.spm.EncodeAsIds(text)[self.pieces_slice],
                          dtype=np.int32)

    def encode_text_as_tokens(self, text: str) -> List[str]:
        """Doesn't produce BOS or EOS tokens."""
        return self.spm.EncodeAsPieces(text)[self.pieces_slice]

    def tokenize(self, text: str) -> List[str]:
        """Alias for `encode_text_as_tokens`. Doesn't produce BOS or EOS
        tokens."""
        return self.encode_text_as_tokens(text)

    def decode_ids_as_text(self, ids: List[int], strip_special=True) -> str:
        """Doesn't produce PAD, BOS, or EOS text; those ids are stripped out
        before decoding. UNK is decoded but unintelligible."""
        if strip_special:
            ids = [
                int(id) for id in ids
                if id not in (self.pad_idx, self.bos_idx, self.eos_idx)
            ]
        else:
            ids = [int(id) for id in ids]
        return self.spm.DecodeIds(ids)

    def decode_tokens_as_text(self, toks: List[str]) -> str:
        """Doesn't produce PAD, BOS, or EOS text; those tokens are stripped
        out before decoding. UNK is decoded but unintelligible."""
        return self.spm.DecodePieces(toks[self.pieces_slice])

    @functools.lru_cache(maxsize=1024)
    def decode_id_as_token(self, id: int) -> str:
        return self.spm.IdToPiece(id)

    def decode_ids_as_tokens(self,
                             ids: List[int],
                             strip_special: bool = True) -> List[str]:
        """By default, doesn't produce PAD, BOS, or EOS tokens.

        Works id-by-id to avoid the problematic intermediate string
        representation that causes length mismatches: SentencePiece isn't
        isomorphic with respect to the string representation.
        """
        if strip_special:
            ids = [
                id for id in ids
                if id not in (self.pad_idx, self.bos_idx, self.eos_idx)
            ]
        return [self.decode_id_as_token(int(ix)) for ix in ids]

    @functools.lru_cache(maxsize=1024)
    def embed_tok(self, tok: str) -> np.ndarray:
        """When given PAD, returns all zeros."""
        if tok == self.pad_token:
            return np.zeros(self.fasttext.dim)
        return np.asarray(self.fasttext[tok])

    def embed_text(self, text: str) -> np.ndarray:
        """Doesn't produce PAD, BOS, or EOS embeddings; those pieces are
        stripped out during tokenization before embedding."""
        return np.asarray([self.embed_tok(tok) for tok in self.tokenize(text)])

    def embed_ids(self,
                  ids: List[int],
                  strip_special: bool = True) -> List[np.ndarray]:
        """By default, doesn't produce PAD, BOS, or EOS embeddings.

        Works token-by-token to avoid the intermediate string representation
        (SentencePiece isn't isomorphic with respect to it).
        """
        return [
            self.embed_tok(t)
            for t in self.decode_ids_as_tokens(ids, strip_special=strip_special)
        ]

    def embed_ids_batch(self, ids: np.ndarray) -> torch.Tensor:
        emb = [self.embed_ids(turn, strip_special=False) for turn in ids]
        return torch.tensor(emb)
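# Hypothetical usage, assuming `preproc/spm.model` and
# `preproc/fasttext-model.bin` exist; the file names follow `from_config`
# above, and the round trip holds for plain, already-normalized text.
embedder = SentencepieceFasttextEmbed("preproc/spm.model",
                                      "preproc/fasttext-model.bin")

ids = embedder.encode_text_as_ids("fix the parser bug")
assert embedder.decode_ids_as_text(ids.tolist()) == "fix the parser bug"

vecs = embedder.embed_text("fix the parser bug")
# One fastText vector per piece: (num_pieces, embed_dim).
assert vecs.shape == (len(ids), embedder.embed_dim)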
import os
import shutil

import torch
from sentencepiece import SentencePieceProcessor

# `constants` (the PAD/EOS/UNK/BOS symbols) and the `SENTENCEPIECE`
# dict-type tag are project-local.


class BPE_Dictionary(object):

    def __init__(
        self,
        dict,
        dict_type,
        pad=constants.PAD,
        eos=constants.EOS,
        unk=constants.UNK,
        bos=constants.BOS,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.dict = os.path.expanduser(dict)
        self.dict_type = dict_type

        if self.dict_type == SENTENCEPIECE:
            assert self.exists(self.dict, self.dict_type)
            self.bpe_dict = SentencePieceProcessor()
            self.bpe_dict.load(f'{self.dict}.model')
            self.pad_index = self.bpe_dict.pad_id()
            self.bos_index = self.bpe_dict.bos_id()
            self.eos_index = self.bpe_dict.eos_id()
            self.unk_index = self.bpe_dict.unk_id()

    @staticmethod
    def exists(dict, dict_type='sentencepiece'):
        dict = os.path.expanduser(dict)
        if dict_type == SENTENCEPIECE:
            dict_file = f'{dict}.model'
            vocab_file = f'{dict}.vocab'
            return os.path.exists(dict_file) and os.path.exists(vocab_file)
        else:
            raise NotImplementedError

    def save(self, dict_name):
        dict_name = os.path.expanduser(dict_name)
        os.makedirs(os.path.dirname(dict_name), exist_ok=True)
        if self.dict_type == SENTENCEPIECE:
            shutil.copy(f'{self.dict}.model', f'{dict_name}.model')
            shutil.copy(f'{self.dict}.vocab', f'{dict_name}.vocab')
        else:
            raise NotImplementedError

    def encode_tokens(self, sentence):
        return self.bpe_dict.EncodeAsPieces(sentence)

    def encode_ids(self, sentence):
        return self.bpe_dict.EncodeAsIds(sentence)

    def string(self, tensor: torch.Tensor, bpe_symbol=None, escape_unk=None,
               trunc_eos=None):
        # A 2-D tensor is decoded row by row, one sentence per line.
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, trunc_eos)
                for t in tensor)
        return self.bpe_dict.Decode(tensor.tolist())

    def __getitem__(self, idx):
        return self.bpe_dict.IdToPiece(idx)

    def __contains__(self, sym):
        return self.index(sym) != self.unk()

    def index(self, sym):
        return self.bpe_dict[sym]

    def __len__(self):
        return len(self.bpe_dict)

    def bos(self):
        """Helper to get the index of the beginning-of-sentence symbol."""
        return self.bos_index

    def pad(self):
        """Helper to get the index of the pad symbol."""
        return self.pad_index

    def eos(self):
        """Helper to get the index of the end-of-sentence symbol."""
        return self.eos_index

    def unk(self):
        """Helper to get the index of the unk symbol."""
        return self.unk_index
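# Hypothetical usage, assuming `vocab/dict.model` and `vocab/dict.vocab` were
# trained with spm_train using --model_prefix=vocab/dict.
d = BPE_Dictionary('vocab/dict', SENTENCEPIECE)

ids = d.encode_ids('def add(a, b): return a + b')
batch = torch.tensor([ids, ids])  # [2, T]
print(d.string(batch))            # two decoded lines, one per row
print(len(d))                     # vocabulary size of the underlying model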