Example #1
def realign_answer_span(features: Features, answer_set: Optional[Set[Text]],
                        processor: spm.SentencePieceProcessor,
                        span: AnswerSpan) -> Optional[AnswerSpan]:
    """Aligns a byte-level answer span to the given SentencePiece tokens."""
    # Find the first token starting at or after the answer's begin byte, then
    # step back if the answer actually begins inside the previous token.
    i = bisect.bisect_left(features.token_offsets, span.begin)
    if (i == len(features.token_offsets)
            or span.begin < features.token_offsets[i]):
        i -= 1
    # Advance j to the last token that starts before the answer's end byte.
    j = i + 1
    answer_end = span.begin + len(span.text.encode('utf-8'))
    while (j < len(features.token_offsets)
           and features.token_offsets[j] < answer_end):
        j += 1
    j -= 1
    # Recover the text covered by tokens [i, j] from the raw context bytes.
    if j + 1 < len(features.token_offsets):
        sp_answer = features.context[
            features.token_offsets[i]:features.token_offsets[j + 1]]
    else:
        sp_answer = features.context[features.token_offsets[i]:]
    # Drop the leading whitespace byte introduced by a word-initial piece ('▁').
    if (processor.IdToPiece(features.token_ids[i]).startswith('▁')
            and features.token_offsets[i] > 0):
        sp_answer = sp_answer[1:]
    sp_answer = evaluation.normalize_answer(sp_answer.decode('utf-8'))
    if answer_set is not None and sp_answer not in answer_set:
        # No need to warn if the mismatch was caused by breaking word
        # boundaries, i.e. the realigned answer grew longer than the original.
        if sp_answer and len(sp_answer) <= len(
                evaluation.normalize_answer(span.text)):
            logging.warning('%s: "%s" not in %s.', features.question_id,
                            sp_answer, answer_set)
        return None
    return AnswerSpan(begin=i, end=j, text=span.text)
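For reference, here is a minimal, self-contained sketch (not from the source project) of the bisect-based lookup the function above relies on: given the sorted byte offsets of SentencePiece tokens, find the token whose span contains a given byte position. `token_offsets` here is an illustrative stand-in for `Features.token_offsets`.

import bisect

def token_index_for_byte(token_offsets, byte_pos):
    """Index of the token whose byte range contains byte_pos (illustrative)."""
    i = bisect.bisect_left(token_offsets, byte_pos)
    # bisect_left lands on the first offset >= byte_pos; step back when the
    # position falls inside the previous token (or past the last offset).
    if i == len(token_offsets) or byte_pos < token_offsets[i]:
        i -= 1
    return i

# Tokens starting at bytes 0, 4, 9 and 15: byte 10 falls inside the third token.
assert token_index_for_byte([0, 4, 9, 15], 10) == 2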
Example #2
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
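A hypothetical usage sketch, assuming a trained model at `spiece.model` and that the `preprocess_text` / `encode_pieces` helpers referenced above are importable in the same module:

tokenizer = SentencePieceTokenizer("spiece.model", do_lower_case=True)

pieces = tokenizer.tokenize("Hello, world!")     # subword pieces, e.g. ['▁hello', ',', '▁world', '!']
ids = tokenizer.convert_tokens_to_ids(pieces)    # ids from the model's vocabulary
assert tokenizer.convert_ids_to_tokens(ids) == pieces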
Example #3
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        if not os.path.exists(spm_file):
            raise ValueError(
                "Can't find spm_file \"%s\". "
                "Please pass the correct path of sentence-piece model file, "
                "e.g.`spiece.model`." % spm_file
            )
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
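Example #3 only adds a file-existence check on top of Example #2. One property worth noting, shown in the hedged sketch below (it assumes a local `spiece.model`), is that SentencePiece maps out-of-vocabulary pieces to the `<unk>` id rather than raising, so `convert_tokens_to_ids` degrades gracefully on unseen pieces:

from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor()
sp.Load("spiece.model")
# Unknown pieces resolve to the model's <unk> id instead of raising an error.
assert sp.PieceToId("definitely-not-a-real-piece") == sp.unk_id()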
Example #4
def beam_search_decode(
    model,
    X,
    sp: spm.SentencePieceProcessor,
    max_decode_len,
    k,
    per_node_k=None,
    constrain_decoding=False,
    sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
):
    if sampler == "top_p":
        sampler = allennlp.nn.beam_search.TopPSampler(
            p=top_p_threshold, temperature=top_p_temperature)
    elif sampler == "deterministic":
        sampler = None
    else:
        raise ValueError("Unsupported sampler")

    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    pad_id = sp.PieceToId("[PAD]")
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    V_full = sp.GetPieceSize()  # Size of vocab
    invalid_vocab_mask = torch.zeros(V_full, dtype=torch.bool, device=X.device)
    if constrain_decoding:
        alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_ "
        for id in range(V_full):
            piece = sp.IdToPiece(id)
            if not (id in [pad_id, bos_id, eos_id] or all(c in alphabet
                                                          for c in piece)):
                invalid_vocab_mask[id] = True
    V = V_full
    model.eval()

    # Encode X
    allen_bs = allennlp.nn.beam_search.BeamSearch(
        end_index=eos_id,
        max_steps=max_decode_len,
        beam_size=k,
        per_node_beam_size=per_node_k,
        sampler=sampler,
    )

    start_predictions = torch.tensor([bos_id] * B,
                                     dtype=torch.long,
                                     device=X.device)
    start_state = {
        "prev_tokens": torch.zeros(B, 0, dtype=torch.long, device=X.device),
        "memory": model.encode(X).transpose(0, 1),  # [B, T, d_model]
    }

    def step(last_tokens, current_state, t):
        """
        Args:
            last_tokens: (group_size,)
            current_state: {}
            t: int
        """
        group_size = last_tokens.size(0)
        prev_tokens = torch.cat(
            [current_state["prev_tokens"],
             last_tokens.unsqueeze(1)], dim=-1)  # [B*k, t+1]

        all_log_probs = model.decode(current_state["memory"].transpose(0, 1),
                                     prev_tokens)
        next_log_probs = all_log_probs[:, -1, :]
        if constrain_decoding:
            next_log_probs = next_log_probs.masked_fill(
                invalid_vocab_mask, float("-inf"))
        next_log_probs = torch.nn.functional.log_softmax(next_log_probs,
                                                         dim=-1)
        assert next_log_probs.shape == (group_size, V)
        return (next_log_probs, {
            "prev_tokens": prev_tokens,
            "memory": current_state["memory"]
        })

    predictions, log_probs = allen_bs.search(
        start_predictions=start_predictions,
        start_state=start_state,
        step=step)

    model.train()
    prediction = ids_to_strs(predictions, sp)
    return prediction, log_probs
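The `ids_to_strs` helper used above is not shown in this example. A minimal sketch of what such a helper might look like (an assumption, not the repository's actual implementation): strip the special ids, then let SentencePiece decode each beam.

def ids_to_strs(predictions, sp):
    """Hypothetical: decode beam-search id tensors back to strings."""
    special = {sp.PieceToId("[PAD]"), sp.PieceToId("<s>"), sp.PieceToId("</s>")}
    decoded = []
    for beams in predictions.tolist():  # predictions: [B, k, T]
        decoded.append([sp.DecodeIds([tok for tok in ids if tok not in special])
                        for ids in beams])
    return decoded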
Example #5
class SentencepieceFasttextEmbed(EmbedderInterface):
    class Config(EmbedderInterface.Config):
        pass

    @classmethod
    def from_config(cls, config: Config):
        spm_model_file = os.path.join(config.preproc_dir, "spm.model")
        fasttext_model_file = os.path.join(config.preproc_dir,
                                           "fasttext-model.bin")
        return cls(spm_model_file, fasttext_model_file, config.max_pieces)

    def __init__(self,
                 spm_model_file: str,
                 fasttext_model_file: str = '',
                 max_pieces: int = -1):
        super().__init__(max_pieces=max_pieces)

        self.spm = SentencePieceProcessor()
        self.spm.Load(spm_model_file)
        self.pad_idx = self.spm.pad_id()
        self.pad_token = self.spm.IdToPiece(self.pad_idx)
        self.unk_idx = self.spm.unk_id()
        self.unk_token = self.spm.IdToPiece(self.unk_idx)
        self.bos_idx = self.spm.bos_id()
        self.bos_token = self.spm.IdToPiece(self.bos_idx)
        self.eos_idx = self.spm.eos_id()
        self.eos_token = self.spm.IdToPiece(self.eos_idx)

        if fasttext_model_file:
            self.fasttext = fasttext.load_model(fasttext_model_file)

    @property
    def embed_dim(self):
        return self.fasttext.dim

    @property
    def n_vocab(self):
        return self.spm.get_piece_size()

    def encode_text_as_ids(self, text: str) -> np.array:
        """
    Doesn't produce BOS, EOS ids.
    """
        return np.asarray(self.spm.EncodeAsIds(text)[self.pieces_slice],
                          dtype=np.int32)

    def encode_text_as_tokens(self, text: str) -> List[str]:
        """
    Doesn't produce BOS, EOS tokens.
    """
        return self.spm.EncodeAsPieces(text)[self.pieces_slice]

    def tokenize(self, text: str) -> List[str]:
        """
    Alias for `encode_text_as_tokens`.
    Doesn't produce BOS, EOS tokens.
    """
        return self.encode_text_as_tokens(text)[self.pieces_slice]

    def decode_ids_as_text(self, ids: List[int], strip_special=True) -> str:
        """
    Doesn't produce PAD, BOS, or EOS text.
    i.e. PAD, BOS, EOS ids are stripped out before decoding.
    UNK is decoded but unintelligible.
    """
        if strip_special:
            ids = [
                int(id) for id in ids
                if id not in (self.pad_idx, self.bos_idx, self.eos_idx)
            ]
        else:
            ids = [int(id) for id in ids]
        return self.spm.DecodeIds(ids)

    def decode_tokens_as_text(self, toks: List[str]) -> str:
        """
    Doesn't produce PAD, BOS, or EOS text.
    i.e. PAD, BOS, EOS tokens are stripped out before decoding.
    UNK is decoded but unintelligible.
    """
        return self.spm.DecodePieces(toks[self.pieces_slice])

    @functools.lru_cache(maxsize=1024)
    def decode_id_as_token(self, id: int) -> str:
        return self.spm.IdToPiece(id)

    def decode_ids_as_tokens(self,
                             ids: List[int],
                             strip_special: bool = True) -> List[str]:
        """
    By default, doesn't produce PAD, BOS, EOS tokens.

    Avoids problematic intermediate string representation that causes length mismatch.
    In other words, SentencePiece isn't isomorphic with respect to the string representation.
    """
        if strip_special:
            ids = [
                id for id in ids
                if id not in (self.pad_idx, self.bos_idx, self.eos_idx)
            ]
        return [self.decode_id_as_token(int(ix)) for ix in ids]

    @functools.lru_cache(maxsize=1024)
    def embed_tok(self, tok: str) -> np.array:
        """
    When given PAD, returns all zeros
    """
        if tok == self.pad_token:
            return np.zeros(self.fasttext.dim)
        return np.asarray(self.fasttext[tok])

    def embed_text(self, text: str) -> np.array:
        """
    Doesn't produce PAD, BOS, EOS embeddings.
    i.e. PAD, BOS, EOS are stripped out during tokenization before embedding.
    """
        return np.asarray([self.embed_tok(tok) for tok in self.tokenize(text)])

    def embed_ids(self,
                  ids: List[int],
                  strip_special: bool = True) -> List[np.array]:
        """
    By default, doesn't produce PAD, BOS, EOS tokens.

    Avoids problematic intermediate string representation that causes length mismatch.
    In other words, SentencePiece isn't isomorphic with respect to the string representation.
    """
        return [
            self.embed_tok(t)
            for t in self.decode_ids_as_tokens(ids,
                                               strip_special=strip_special)
        ]

    def embed_ids_batch(self, ids: np.array) -> torch.tensor:
        emb = [self.embed_ids(turn, strip_special=False) for turn in ids]
        emb = torch.tensor(emb)
        return emb
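A hypothetical usage sketch; the file names below are illustrative stand-ins for the artifacts produced by the project's preprocessing step (`spm.model`, `fasttext-model.bin`):

embedder = SentencepieceFasttextEmbed("spm.model", "fasttext-model.bin")

ids = embedder.encode_text_as_ids("how are you")   # int32 piece ids, no BOS/EOS
vectors = embedder.embed_text("how are you")       # one fastText vector per piece
assert vectors.shape[1] == embedder.embed_dim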
Example #6
class BPE_Dictionary(object):
    def __init__(
        self,
        dict,
        dict_type,
        pad=constants.PAD,
        eos=constants.EOS,
        unk=constants.UNK,
        bos=constants.BOS,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.dict = os.path.expanduser(dict)
        self.dict_type = dict_type

        if self.dict_type == SENTENCEPIECE:
            assert self.exists(self.dict, self.dict_type)
            self.bpe_dict = SentencePieceProcessor()
            self.bpe_dict.load(f'{self.dict}.model')
            self.pad_index = self.bpe_dict.pad_id()
            self.bos_index = self.bpe_dict.bos_id()
            self.eos_index = self.bpe_dict.eos_id()
            self.unk_index = self.bpe_dict.unk_id()

    @staticmethod
    def exists(dict, dict_type='sentencepiece'):
        dict = os.path.expanduser(dict)
        if dict_type == SENTENCEPIECE:
            dict_file = f'{dict}.model'
            vocab_file = f'{dict}.vocab'
            if os.path.exists(dict_file) and os.path.exists(vocab_file):
                return True
            else:
                return False
        else:
            raise NotImplementedError

    def save(self, dict_name):
        dict_name = os.path.expanduser(dict_name)
        os.makedirs(os.path.dirname(dict_name), exist_ok=True)
        if self.dict_type == SENTENCEPIECE:
            shutil.copy(f'{self.dict}.model', f'{dict_name}.model')
            shutil.copy(f'{self.dict}.vocab', f'{dict_name}.vocab')
        else:
            raise NotImplementedError

    def encode_tokens(self, sentence):
        return self.bpe_dict.EncodeAsPieces(sentence)

    def encode_ids(self, sentence):
        return self.bpe_dict.EncodeAsIds(sentence)

    def string(self,
               tensor: torch.Tensor,
               bpe_symbol=None,
               escape_unk=None,
               trunc_eos=None):
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, trunc_eos)
                for t in tensor)
        return self.bpe_dict.Decode(tensor.tolist())

    def __getitem__(self, idx):
        return self.bpe_dict.IdToPiece(idx)

    def __contains__(self, sym):
        return self.index(sym) != self.unk()

    def index(self, sym):
        return self.bpe_dict[sym]

    def __len__(self):
        return len(self.bpe_dict)

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index
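A hypothetical usage sketch, assuming `SENTENCEPIECE` is the project's `'sentencepiece'` constant and that a trained pair `~/data/bpe.model` / `~/data/bpe.vocab` exists:

d = BPE_Dictionary("~/data/bpe", SENTENCEPIECE)

ids = d.encode_ids("hello world")            # piece ids for the sentence
text = d.string(torch.tensor(ids))           # decodes the 1-D id tensor back to text
print(len(d), d.pad(), d.eos(), d.unk())     # vocab size and special-symbol indices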