# Constructor of a question-answering dataset; the enclosing class statement
# is omitted in this excerpt.
def __init__(
    self,
    examples: List[QAExample],
    tokenizer: SentencePieceBPETokenizer,
    max_sequence_length: int,
    is_train: bool = True,
) -> None:
    self.examples = examples
    self.tokenizer = tokenizer
    self.max_sequence_length = max_sequence_length

    self.sos_token = tokenizer.token_to_id("<s>")
    self.eos_token = tokenizer.token_to_id("</s>")
    # "질문:" is Korean for "Question:"; its subword ids form the question prefix.
    self.question_prefix_tokens = self.tokenizer.encode("질문:").ids

    self.is_train = is_train
import os

from tokenizers import SentencePieceBPETokenizer
from tokenizers.processors import BertProcessing


def load_tokenizer(path,
                   enable_truncation=True,
                   enable_padding=True,
                   max_length=512):
    tokenizer = SentencePieceBPETokenizer(os.path.join(path, "vocab.json"),
                                          os.path.join(path, "merges.txt"))
    tokenizer._tokenizer.post_processor = BertProcessing(
        ("</s>", tokenizer.token_to_id("</s>")),
        ("<s>", tokenizer.token_to_id("<s>")),
    )
    if enable_truncation:
        tokenizer.enable_truncation(max_length=max_length)
    if enable_padding:
        tokenizer.enable_padding(pad_token="<pad>",
                                 pad_id=tokenizer.token_to_id("<pad>"))
    return tokenizer
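A minimal usage sketch for load_tokenizer above; tokenizer_dir is a hypothetical directory that is assumed to contain vocab.json and merges.txt:

tokenizer = load_tokenizer("tokenizer_dir", max_length=128)  # hypothetical path
encoded = tokenizer.encode("How are you?")
print(encoded.tokens)  # subword tokens wrapped as <s> ... </s> by BertProcessing
print(encoded.ids)     # matching token ids, truncated to at most 128 entries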
class TokenizerWrapper:
    def __init__(self, tok_type, unk_token, sep_token, cls_token, pad_token,
                 mask_token):
        self.tok_type = tok_type

        if self.tok_type == 'bpe':
            self.tokenizer = ByteLevelBPETokenizer()
        elif self.tok_type == 'wordpiece':
            self.tokenizer = BertWordPieceTokenizer(unk_token=unk_token,
                                                    sep_token=sep_token,
                                                    cls_token=cls_token,
                                                    pad_token=pad_token,
                                                    mask_token=mask_token)
        elif self.tok_type == 'sentencepiece':
            self.tokenizer = SentencePieceBPETokenizer(unk_token=unk_token)

    def train(self, data_file, vocab_size, special_tokens):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            self.tokenizer.train([data_file],
                                 vocab_size=vocab_size,
                                 special_tokens=special_tokens)

    def tokenize(self, text):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            return self.tokenizer.encode(text).tokens
        elif self.tok_type == 'word':
            return nltk.tokenize.word_tokenize(text)
        elif self.tok_type == 'char':
            return [c for c in text]
        else:
            raise Exception('Unknown tokenizer: ' + self.tok_type)

    def decode(self, tokens, blank_token):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            ids = [self.tokenizer.token_to_id(t) for t in tokens]
            ids = [
                i if i is not None else self.tokenizer.token_to_id(blank_token)
                for i in ids
            ]
            return self.tokenizer.decode(ids, skip_special_tokens=False)
        elif self.tok_type == 'word':
            return ' '.join(tokens)
        elif self.tok_type == 'char':
            return ''.join(tokens)
        else:
            raise Exception('Unknown tokenizer: ' + self.tok_type)
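A short usage sketch for TokenizerWrapper, assuming the snippet's module-level imports (nltk and the tokenizers classes) are available; corpus.txt is a hypothetical plain-text training file and the special-token strings are illustrative values passed straight through to the underlying tokenizer:

wrapper = TokenizerWrapper(tok_type='wordpiece',
                           unk_token='[UNK]', sep_token='[SEP]', cls_token='[CLS]',
                           pad_token='[PAD]', mask_token='[MASK]')
wrapper.train('corpus.txt', vocab_size=8000,
              special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])
tokens = wrapper.tokenize('a quick example sentence')  # list of wordpiece strings
print(wrapper.decode(tokens, blank_token='[UNK]'))     # ids looked up per token, then detokenized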
class DecodeBySentencePieceBPETokenizer(Preprocessor):
    __provider__ = 'decode_by_sentence_piece_bpe_tokenizer'

    @classmethod
    def parameters(cls):
        parameters = super().parameters()
        parameters.update({
            'vocabulary_file': PathField(),
            'merges_file': PathField(),
            'sos_symbol': StringField(optional=True, default='<s>'),
            'eos_symbol': StringField(optional=True, default='</s>'),
            'add_extra_symbols': BoolField(optional=True, default=True),
        })

        return parameters

    def configure(self):
        if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
            SentencePieceBPETokenizer.raise_error(self.__provider__)
        self.tokenizer = SentencePieceBPETokenizer(
            str(self.get_value_from_config('vocabulary_file')),
            str(self.get_value_from_config('merges_file')))
        self.add_extra_symbols = self.get_value_from_config(
            'add_extra_symbols')
        self.idx = {}
        for s in ['sos', 'eos']:
            self.idx[s] = self.tokenizer.token_to_id(
                str(self.get_value_from_config(s + '_symbol')))

    def process(self, image, annotation_meta=None):
        sentence = " ".join(image.data)
        tokens = self.tokenizer.encode(sentence).ids
        if self.add_extra_symbols:
            tokens = [self.idx['sos']] + tokens + [self.idx['eos']]
        image.data = tokens
        image.metadata['decoded'] = True
        image.identifier = "tokens"

        return image
class Tokenizer:
    """ Sentence tokenizer.

    Arguments:
        path (str): path to the tokenizer's model folder (contains vocab.json and merges.txt).
        max_tokens (int): maximum number of tokens per encoded sequence,
            including special tokens and padding.
    """
    def __init__(self, path, max_tokens):
        self.logger = log.getLogger("Tokenizer")
        self.logger.info("loading tokenizer")
        self.logger.info("path: " + path)
        self.logger.info("max_tokens: " + str(max_tokens))
        self.tokenizer = SentencePieceBPETokenizer(
            os.path.join(path, "vocab.json"),
            os.path.join(path, "merges.txt")
        )
        self.max_tokens = max_tokens
        self.idx = {}
        for s in ['</s>', '<s>', '<pad>']:
            self.idx[s] = self.tokenizer.token_to_id(s)

    def encode(self, sentence):
        """ Encode method for sentence.

        Arguments:
            sentence (str): sentence.

        Returns:
            tokens (list of int): encoded sentence in tokenized format.
        """
        tokens = self.tokenizer.encode(sentence).ids
        return self._extend_tokens(tokens)

    def decode(self, tokens, remove_repeats=True):
        """ Decode method for tokens.

        Arguments:
            tokens (np.array): sentence in tokenized format.
            remove_repeats (bool): remove repeated words.

        Returns:
            sentence (str): output sentence.
        """
        sentence = self.tokenizer.decode(tokens)
        for s in self.idx.keys():
            sentence = sentence.replace(s, '')
        if remove_repeats:
            sentence = self._remove_repeats(sentence)
        return sentence.lstrip()

    def _extend_tokens(self, tokens):
        """ Extend tokens.

        Arguments:
            tokens (np.array): sentence in tokenized format.

        Returns:
            tokens (np.array): extended tokens.
        """
        tokens = [self.idx['<s>']] + tokens + [self.idx['</s>']]
        pad_length = self.max_tokens - len(tokens)
        if pad_length > 0:
            tokens = tokens + [self.idx['<pad>']] * pad_length
        return tokens

    def _remove_repeats(self, sentence):
        """ Remove repeated words.

        Arguments:
            sentence (str): sentence.

        Returns:
            sentence (str): lowercased sentence with consecutive repeated words removed.
        """
        tokens = sentence.lower().split()
        return " ".join(key for key, _ in itertools.groupby(tokens))
class BPEVocabulary(Vocabulary):
    """ Represents a SentencePiece vocabulary for c2s.
    """
    def __init__(self, args: Namespace):
        super().__init__()

        self.target_encoder = SentencePieceBPETokenizer(
            args.target_vocab, args.target_merges)
        self.subtoken_encoder = SentencePieceBPETokenizer(
            args.subtoken_vocab, args.subtoken_merges)
        # self.target_encoder.add_special_tokens(
        #     [self.EOS_TOKEN, self.SOS_TOKEN, self.PAD_TOKEN]
        # )
        # self.subtoken_encoder.add_special_tokens([self.EOS_TOKEN, self.PAD_TOKEN])

        with open(args.node_dict, "rb") as f:
            self.node_to_index = pickle.load(f)
            self.index_to_node = {v: k for k, v in self.node_to_index.items()}

    def target_vocab_size(self):
        # print(self.target_encoder.num_special_tokens_to_add())
        return self.target_encoder.get_vocab_size() + 4

    def node_vocab_size(self):
        # print(self.target_encoder.num_special_tokens_to_add())
        return len(self.node_to_index) + 2

    def terminal_vocab_size(self):
        return self.subtoken_encoder.get_vocab_size() + 4

    def add_special_target_token(self, token: str):
        self.target_encoder.add_special_tokens([token])

    def add_special_terminal_token(self, token: str):
        self.subtoken_encoder.add_special_tokens([token])

    def encode_node(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.node_to_index.get(token_or_tokens,
                                          self.node_to_index[self.UNK_TOKEN])
        else:
            return list(map(self.encode_node, token_or_tokens))

    def decode_node(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.index_to_node[index_or_indices]
        else:
            return list(map(self.decode_node, index_or_indices))

    def encode_target(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.target_encoder.token_to_id(token_or_tokens)
        else:
            return self.target_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_target(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.target_encoder.id_to_token(index_or_indices)
        else:
            return self.target_encoder.decode(index_or_indices)

    def encode_terminal(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.subtoken_encoder.token_to_id(token_or_tokens)
        else:
            return self.subtoken_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_terminal(self, index_or_indices):
        # Uses the same sub-token encoder as encode_terminal above.
        if isinstance(index_or_indices, int):
            return self.subtoken_encoder.id_to_token(index_or_indices)
        else:
            return self.subtoken_encoder.decode(index_or_indices)
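A construction sketch for BPEVocabulary, assuming the surrounding module's Vocabulary base class and imports are available; every path in the Namespace is hypothetical, and node_dict is expected to point at a pickled token-to-index mapping:

from argparse import Namespace

args = Namespace(target_vocab="target/vocab.json", target_merges="target/merges.txt",
                 subtoken_vocab="subtoken/vocab.json", subtoken_merges="subtoken/merges.txt",
                 node_dict="nodes.pkl")
vocab = BPEVocabulary(args)
target_ids = vocab.encode_target(["return", "sum"])  # BPE ids of the joined target tokens
print(vocab.decode_target(target_ids))
print(vocab.node_vocab_size())                       # len(node_to_index) + 2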
class TextProcessor:
    def __init__(self, tok_model_path: Optional[str] = None):
        self.languages = {}
        if tok_model_path is not None:
            self.tokenizer = SentencePieceBPETokenizer(
                tok_model_path + "/vocab.json",
                tok_model_path + "/merges.txt",
            )
            with open(os.path.join(tok_model_path, "langs"), "rb") as fp:
                self.languages: Dict[str, int] = pickle.load(fp)
        self.init_properties(self.languages)

    def init_properties(self, languages: Optional[Dict[str, int]] = None):
        self.max_len = 512
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"
        self.unk_token = "<unk>"
        self.sep_token = "</s>"
        self.bos = "<s>"
        self.special_tokens = [
            self.pad_token, self.bos, self.unk_token, self.mask_token,
            self.sep_token
        ] + list(languages.keys())
        self.languages = languages

    def train_tokenizer(self, paths: List[str], vocab_size: int,
                        to_save_dir: str, languages: Dict[str, int]):
        self.tokenizer = SentencePieceBPETokenizer()
        self.init_properties(languages)
        self.tokenizer.train(files=paths,
                             vocab_size=vocab_size,
                             min_frequency=5,
                             special_tokens=self.special_tokens)
        self.save(directory=to_save_dir)

    def _tokenize(self, line) -> Encoding:
        return self.tokenizer.encode(line)

    def save(self, directory):
        self.tokenizer.save(directory)
        with open(os.path.join(directory, "langs"), "wb") as fp:
            pickle.dump(self.languages, fp)

    def tokenize_one_line(self,
                          line,
                          ignore_middle_eos: bool = False) -> List[int]:
        tokenized = []
        spl = [sen for sen in line.split("</s>") if len(sen.strip()) > 0]
        if spl[0].startswith("<"):
            words = spl[0].strip().split(" ")
            spl[0] = " ".join(words[1:])
            tokenized += [self.token_id(words[0])]

        for sen in spl:
            tokenized += self._tokenize(sen).ids
            if not ignore_middle_eos:
                tokenized += [self.sep_token_id()]
        if ignore_middle_eos:
            tokenized += [self.sep_token_id()]
        return tokenized

    def tokenize_one_sentence(self, line) -> List[int]:
        """
        Assumes the first token of the line is a language id and the last token
        is an end-of-sentence symbol.
        :param line:
        :return:
        """
        spl = line.strip().split(" ")
        lang_id, sen, eos = spl[0], " ".join(spl[1:-1]), spl[-1]
        tokenized = ([self.token_id(lang_id)] + self._tokenize(sen).ids +
                     [self.token_id(eos)])
        return tokenized

    def tokenize_lines(self,
                       line,
                       blind_split: bool = False,
                       split_len: int = 512) -> List[List[int]]:
        """

        :param line:
        :param blind_split: If True, just splits the tokenized data into chunks without considering that every vector
        should start with a first word in sentence.
        :return:
        """
        tokenized = []
        if len(self.languages) > 0:
            spl = [sen for sen in line.split("</s>") if len(sen.strip()) > 0]
            lang_id = []
            if spl[0].startswith("<"):
                words = spl[0].strip().split(" ")
                lang_id = [self.token_id(words[0])]
                spl[0] = " ".join(words[1:])

            max_len = 0
            for sen in spl:
                toks = self._tokenize(sen).ids
                tokenized += lang_id + toks + [self.sep_token_id()]
                max_len = max(max_len, len(toks) + 1)
        else:
            tokenized = self._tokenize(line.strip()).ids
            # max_len is only assigned in the branch above; default it here so the
            # split_tokenized call below does not fail when no languages are loaded.
            max_len = self.max_len

        if blind_split:
            num_pads = (split_len - (len(tokenized) % split_len))
            pad_arr = [self.pad_token_id()] * num_pads
            tokenized = np.array(tokenized + pad_arr)
            reshaped = tokenized.reshape((-1, split_len))
            return reshaped
        else:
            return self.split_tokenized(tokenized, min(max_len, self.max_len))

    def tokenize(self, lines) -> List[List[int]]:
        lines = [
            line.strip() for line in lines.strip().split("\n")
            if len(line.strip()) > 0
        ]
        tokenized = self.tokenizer.encode_batch(lines)
        return [tok.ids for tok in tokenized]

    def pad_token_id(self) -> int:
        return self.tokenizer.token_to_id(self.pad_token)

    def mask_token_id(self) -> int:
        return self.tokenizer.token_to_id(self.mask_token)

    def unk_token_id(self) -> int:
        return self.tokenizer.token_to_id(self.unk_token)

    def bos_token_id(self) -> int:
        return self.tokenizer.token_to_id(self.bos)

    def sep_token_id(self) -> int:
        return self.tokenizer.token_to_id(self.sep_token)

    def token_id(self, token: str) -> int:
        tok_id = self.tokenizer.token_to_id(token)
        if tok_id is None:
            return 0
        return tok_id

    def id2token(self, id: int) -> str:
        return self.tokenizer.id_to_token(id)

    def vocab_size(self) -> int:
        return self.tokenizer.get_vocab_size()

    def is_lang(self, id) -> bool:
        return self.tokenizer.id_to_token(id) in self.languages

    def lang_id(self, tok):
        if tok in self.languages:
            return self.languages[tok]
        return 0

    def split_tokenized(self,
                        tokenized: List[int],
                        max_length: int = 512) -> List[List[int]]:
        """
        Splits very long sequences into chunks of at most max_length tokens.
        The resulting chunks do not overlap.
        If the first token is a language token, it is prepended to every new chunk.
        :return:
        """
        if len(tokenized) <= max_length:
            sequences = [tokenized]
            sequences[-1] = sequences[-1] + (
                max_length - len(sequences[-1])) * [self.pad_token_id()]
            return sequences

        has_lang = self.is_lang(tokenized[0])
        sequence = tokenized[0:] if has_lang else tokenized

        seq_len = len(sequence)
        sep_id = self.sep_token_id()
        max_len = max_length - 1 if has_lang else max_length

        cur_start = 0
        sequences = []
        built_seq = []
        truncated = False  # Whether the previous sequence was truncated because of its length.
        used_ends = set()
        while cur_start < seq_len:
            if not truncated or not has_lang:
                cur_end = min(seq_len, cur_start + max_len)
            else:
                cur_end = min(seq_len, cur_start + max_len + 1)
            subseq = sequence[cur_start:cur_end]

            built_seq += subseq
            sep_positions = [
                i for i, id in enumerate(built_seq) if id == sep_id
            ]
            if len(sep_positions) > 0:
                if sep_positions[-1] in used_ends:
                    truncated = True
                else:
                    built_seq = built_seq[:sep_positions[-1] + 1]
                    truncated = False
            else:
                truncated = True

            assert built_seq[-1] == sequence[len(built_seq) - 1]

            if has_lang and len(subseq) < max_len + 1:
                subseq = [tokenized[0]] + subseq

            sequences.append(subseq)

            cur_start = len(built_seq)
            used_ends.add(cur_start - 1)
        if len(sequences[-1]) < max_length:
            sequences[-1] = sequences[-1] + (
                max_length - len(sequences[-1])) * [self.pad_token_id()]
        assert built_seq[-1] == sequence[len(built_seq) - 1]
        return sequences
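Finally, a usage sketch for TextProcessor, assuming the snippet's module-level imports and a hypothetical trained tokenizer directory tok_model/ that also holds the pickled langs mapping, with a language tag such as <en> included among the special tokens at training time:

proc = TextProcessor(tok_model_path="tok_model")
ids = proc.tokenize_one_line("<en> hello world </s> how are you </s>")
print(ids)                                         # [<en> id, subword ids ..., </s> id, ...]
chunks = proc.split_tokenized(ids, max_length=16)  # pads (or splits) into fixed-length chunks
print([len(c) for c in chunks])                    # a single chunk of length 16 for this short input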