def test_train_from_iterator(self):
        text = ["A first sentence", "Another sentence", "And a last one"]
        tokenizer = SentencePieceBPETokenizer()
        tokenizer.train_from_iterator(text, show_progress=False)

        output = tokenizer.encode("A sentence")
        assert output.tokens == ["▁A", "▁sentence"]
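For reference, a minimal round-trip sketch along the same lines (the output filename is illustrative, not part of the original test): train from an in-memory iterator, save the single-file JSON, and reload it with Tokenizer.from_file.

from tokenizers import SentencePieceBPETokenizer, Tokenizer

# Minimal sketch: train on an in-memory corpus and round-trip through a JSON file.
corpus = ["A first sentence", "Another sentence", "And a last one"]
tokenizer = SentencePieceBPETokenizer()
tokenizer.train_from_iterator(corpus, vocab_size=500, show_progress=False)
tokenizer.save("sp_bpe.tokenizer.json")  # single-file serialization

reloaded = Tokenizer.from_file("sp_bpe.tokenizer.json")
print(reloaded.encode("A sentence").tokens)  # expected to contain "▁A" and "▁sentence"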
 def train_tokenizer(self, paths: List[str], vocab_size: int,
                     to_save_dir: str, languages: Dict[str, int]):
     self.tokenizer = SentencePieceBPETokenizer()
     self.init_properties(languages)
     self.tokenizer.train(files=paths,
                          vocab_size=vocab_size,
                          min_frequency=5,
                          special_tokens=self.special_tokens)
     self.save(directory=to_save_dir)
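For context, a standalone sketch of the same kind of training call outside the class (file and directory names are illustrative): save_model writes the vocab.json / merges.txt pair that the loaders in the later examples expect.

from tokenizers import SentencePieceBPETokenizer

# Standalone sketch (illustrative paths): train on text files and write vocab.json / merges.txt.
tokenizer = SentencePieceBPETokenizer()
tokenizer.train(files=["corpus_a.txt", "corpus_b.txt"],
                vocab_size=30000,
                min_frequency=5,
                special_tokens=["<unk>", "<s>", "</s>", "<pad>"])
tokenizer.save_model("out_dir")  # writes out_dir/vocab.json and out_dir/merges.txt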
 def __init__(self, dataset_folder, tokenizer_method):
     self.dataset_folder = dataset_folder
     self.tokenizer_method = tokenizer_method
     if tokenizer_method == "sentencepiece":
         self.tokenizer = SentencePieceBPETokenizer(
             "./data/sentencepiece_tokenizer/vocab.json",
             "./data/sentencepiece_tokenizer/merges.txt")
     elif tokenizer_method == "bert":
         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 def __init__(self, tok_model_path: Optional[str] = None):
     self.languages = {}
     if tok_model_path is not None:
         self.tokenizer = SentencePieceBPETokenizer(
             tok_model_path + "/vocab.json",
             tok_model_path + "/merges.txt",
         )
         with open(os.path.join(tok_model_path, "langs"), "rb") as fp:
             self.languages: Dict[str, int] = pickle.load(fp)
     self.init_properties(self.languages)
Example #5
 def __init__(self, path, max_tokens):
     self.logger = log.getLogger("Tokenizer")
     self.logger.info("loading tokenizer")
     self.logger.info("path: " + path)
     self.logger.info("max_tokens: " + str(max_tokens))
     self.tokenizer = SentencePieceBPETokenizer(
         os.path.join(path, "vocab.json"), os.path.join(path, "merges.txt"))
     self.max_tokens = max_tokens
     self.idx = {}
     for s in ['</s>', '<s>', '<pad>']:
         self.idx[s] = self.tokenizer.token_to_id(s)
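The special-token ids cached above are typically used when encoding; a hypothetical companion method (not part of the original class) might look like:

 def encode(self, sentence):
     # Hypothetical helper (not in the original class): truncate to max_tokens,
     # leaving room for the <s>/</s> ids cached in self.idx above.
     body_ids = self.tokenizer.encode(sentence).ids[:self.max_tokens - 2]
     return [self.idx['<s>']] + body_ids + [self.idx['</s>']]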
Example #6
 def fit_on_text(self, text):
     if self.lower:
         text = text.lower()
     words = text.split()
     tokenizer1 = SentencePieceBPETokenizer(vocab, merges)  # vocab/merges: vocab.json and merges.txt paths defined elsewhere in the module
     for word in words:
         for sub_word in tokenizer1.encode(word).tokens:
             if sub_word not in self.word2idx:
                 self.word2idx[sub_word] = self.idx
                 self.idx2word[self.idx] = sub_word
                 self.idx += 1
Example #7
 def configure(self):
     if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
         SentencePieceBPETokenizer.raise_error(self.__provider__)
     self.tokenizer = SentencePieceBPETokenizer(
         str(self.get_value_from_config('vocabulary_file')),
         str(self.get_value_from_config('merges_file')))
     self.add_extra_symbols = self.get_value_from_config(
         'add_extra_symbols')
     self.idx = {}
     for s in ['sos', 'eos']:
         self.idx[s] = self.tokenizer.token_to_id(
             str(self.get_value_from_config(s + '_symbol')))
Example #8
 def __init__(self, path, max_tokens):
     self.logger = log.getLogger("Tokenizer")
     self.logger.info("loading tokenizer")
     self.logger.info(f"path: {path}")
     self.logger.info(f"max_tokens: {max_tokens}")
     self.tokenizer = SentencePieceBPETokenizer(
         str(path / "vocab.json"),
         str(path / "merges.txt"),
     )
     self.max_tokens = max_tokens
     self.idx = {}
     for s in ['</s>', '<s>', '<pad>']:
         self.idx[s] = self.tokenizer.token_to_id(s)
Example #9
 def configure(self):
     if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
         SentencePieceBPETokenizer.raise_error(self.__provider__)
     self.tokenizer = SentencePieceBPETokenizer(
         str(self.get_value_from_config('vocabulary_file')),
         str(self.get_value_from_config('merges_file')))
     self.remove_extra_symbols = self.get_value_from_config(
         'remove_extra_symbols')
     self.idx = {}
     for s in ['sos', 'eos', 'pad']:
         self.idx[s] = str(self.get_value_from_config(s + '_symbol'))
     self.output_name = self.get_value_from_config('output_name')
     self.output_checked = False
    def __init__(self, tok_type, unk_token, sep_token, cls_token, pad_token,
                 mask_token):
        self.tok_type = tok_type

        if self.tok_type == 'bpe':
            self.tokenizer = ByteLevelBPETokenizer()
        elif self.tok_type == 'wordpiece':
            self.tokenizer = BertWordPieceTokenizer(unk_token=unk_token,
                                                    sep_token=sep_token,
                                                    cls_token=cls_token,
                                                    pad_token=pad_token,
                                                    mask_token=mask_token)
        elif self.tok_type == 'sentencepiece':
            self.tokenizer = SentencePieceBPETokenizer(unk_token=unk_token)
Example #11
    def set_tokenizer(self):

        if self.storage_method == "raw":
            pass  # Essentially keep it None. Important for exceptions
        elif self.storage_method == "bert":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        elif self.storage_method == "roberta":
            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        elif self.storage_method == "token":
            self.tokenizer = SentencePieceBPETokenizer(
                os.path.join(self.tokenizer_path, "vocab.json"),
                os.path.join(self.tokenizer_path, "merges.txt"))
        else:
            raise ValueError("Unknown storage method encountered!")
Example #12
    def __init__(self, args):
        self.args = args
        if self.args.type == "byte":
            self.tokenizer = ByteLevelBPETokenizer(
                add_prefix_space=True,  # required
                lowercase=True,  # required
                unicode_normalizer=None,  # required
                vocab_file=None,
                merges_file=None,
                dropout=None,
                continuing_subword_prefix=None,
                end_of_word_suffix=None)

        elif self.args.type == "char":
            self.tokenizer = CharBPETokenizer(
                unk_token=unk_token,  # required
                suffix=suffix_token,  # required
                lowercase=True,  # required
                unicode_normalizer=None,  # required
                vocab_file=None,
                merges_file=None,
                dropout=None)

        elif self.args.type == "bert":
            self.tokenizer = BertWordPieceTokenizer(
                clean_text=True,  # required
                handle_chinese_chars=True,  # required
                strip_accents=True,  # required
                lowercase=True,  # required
                vocab_file=None,
                # add_special_tokens=True,
                unk_token=BUNK,
                sep_token=BSEP,
                cls_token=BCLS,
                wordpieces_prefix=BPRE)

        elif self.args.type == "sent":
            self.tokenizer = SentencePieceBPETokenizer(
                add_prefix_space=True,  # required
                unk_token=unk_token,
                replacement=rep_token,
                vocab_file=None,
                merges_file=None,
                dropout=None)

        else:
            raise Exception("Not implement yet")

        pass
    def __init__(self, args: Namespace):
        super().__init__()

        self.target_encoder = SentencePieceBPETokenizer(
            args.target_vocab, args.target_merges)
        self.subtoken_encoder = SentencePieceBPETokenizer(
            args.subtoken_vocab, args.subtoken_merges)
        # self.target_encoder.add_special_tokens(
        #     [self.EOS_TOKEN, self.SOS_TOKEN, self.PAD_TOKEN]
        # )
        # self.subtoken_encoder.add_special_tokens([self.EOS_TOKEN, self.PAD_TOKEN])

        with open(args.node_dict, "rb") as f:
            self.node_to_index = pickle.load(f)
            self.index_to_node = {v: k for k, v in self.node_to_index.items()}
Example #14
    def load(vocab_file=None):
        if not os.path.exists(vocab_file):
            raise Exception("{} is not exist".format(vocab_file))
        path, filename = os.path.split(vocab_file)
        ttype = filename.split("_")[0]
        merges_file = os.path.join(
            path, filename.replace("vocab.json", "merges.txt"))
        if ttype == "byte":
            if not os.path.exists(merges_file):
                raise Exception("{} is not exist".format(merges_file))
            tokenizer = ByteLevelBPETokenizer(
                add_prefix_space=True,  # required
                lowercase=True,  # required
                unicode_normalizer=None,  # required
                vocab_file=vocab_file,
                merges_file=merges_file,
                dropout=None,
                continuing_subword_prefix=None,
                end_of_word_suffix=None)

        elif ttype == "char":
            if not os.path.exists(merges_file):
                raise Exception("{} is not exist".format(merges_file))
            tokenizer = CharBPETokenizer(
                unk_token=unk_token,  # required
                suffix=suffix_token,  # required
                lowercase=True,  # required
                unicode_normalizer=None,  # required
                vocab_file=vocab_file,
                merges_file=merges_file,
                dropout=None)

        elif ttype == "bert":
            tokenizer = BertWordPieceTokenizer(
                clean_text=True,  # required
                handle_chinese_chars=True,  # required
                strip_accents=True,  # required
                lowercase=True,  # required
                vocab_file=vocab_file,
                # add_special_tokens=True,
                unk_token=BUNK,
                sep_token=BSEP,
                cls_token=BCLS,
                wordpieces_prefix=BPRE)

        elif ttype == "sent":
            if not os.path.exists(merges_file):
                raise Exception("{} is not exist".format(merges_file))
            tokenizer = SentencePieceBPETokenizer(
                add_prefix_space=True,  # required
                unk_token=unk_token,
                replacement=rep_token,
                vocab_file=vocab_file,
                merges_file=merges_file,
                dropout=None)

        else:
            raise Exception("Not implement yet")

        return tokenizer
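A usage sketch for the loader above (the path is illustrative, and load is assumed to be reachable as a static helper): the prefix before the first underscore in the vocab filename selects the tokenizer type, and the sibling merges file is resolved by replacing "vocab.json" with "merges.txt".

# Usage sketch (illustrative path): "sent_vocab.json" selects the SentencePiece BPE branch
# and expects "sent_merges.txt" in the same directory.
tokenizer = load("checkpoints/sent_vocab.json")
print(tokenizer.encode("hello world").tokens)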
    def __init__(
        self,
        examples: List[QAExample],
        tokenizer: SentencePieceBPETokenizer,
        max_sequence_length: int,
        is_train: bool = True,
    ) -> None:
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length

        self.sos_token = tokenizer.token_to_id("<s>")
        self.eos_token = tokenizer.token_to_id("</s>")
        self.question_prefix_tokens = self.tokenizer.encode("질문:").ids  # "질문:" = "Question:" prompt prefix

        self.is_train = is_train
def main():
    config = QGConfig()
    args = parser.parse_args()

    model = GPT2LMHeadModel.from_pretrained("taeminlee/kogpt2")
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename="tokenizer/vocab.json", merges_filename="tokenizer/merges.txt", add_prefix_space=False
    )
    examples = load_korquad_dataset(config.dev_dataset)
    dataset = QGDataset(examples, tokenizer, config.max_sequence_length)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=dynamic_padding_collate_fn)

    model.eval()

    loss_list = []
    for batch_data in tqdm(dataloader, desc="[EVAL]"):
        with torch.no_grad():
            input_ids, attention_mask, labels = tuple(value.to(device) for value in batch_data)
            model_outputs = model.forward(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
            loss_list.append(model_outputs.loss.item())

    mean_loss = np.mean(loss_list)
    print(f"loss:{mean_loss:.4f} perplexity:{math.exp(mean_loss):.4f}")
    model.train()
Example #17
def main():
    config = QGConfig()
    args = parser.parse_args()

    model = GPT2LMHeadModel.from_pretrained("taeminlee/kogpt2")
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    tokenizer = SentencePieceBPETokenizer.from_file(
        vocab_filename="tokenizer/vocab.json",
        merges_filename="tokenizer/merges.txt",
        add_prefix_space=False)
    examples = load_korquad_dataset(config.dev_dataset)
    random.shuffle(examples)
    examples = examples[:args.num_samples]
    dataset = QGDecodingDataset(examples, tokenizer,
                                config.max_sequence_length)
    dataloader = DataLoader(dataset, batch_size=1)

    model.eval()

    generated_results = []

    for i, batch in tqdm(enumerate(dataloader),
                         desc="generate",
                         total=len(dataloader)):
        input_ids, attention_mask = (v.to(device) for v in batch)
        origin_seq_len = input_ids.size(-1)

        decoded_sequences = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=origin_seq_len + MAX_QUESTION_SPACE,
            min_length=origin_seq_len + MIN_QUESTION_SPACE,
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
            num_beams=args.num_beams,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
        )

        for decoded_tokens in decoded_sequences.tolist():
            decoded_question_text = tokenizer.decode(
                decoded_tokens[origin_seq_len:])
            decoded_question_text = decoded_question_text.split(
                "</s>")[0].replace("<s>", "")
            generated_results.append(
                (examples[i].context, examples[i].answer, examples[i].question,
                 decoded_question_text))

    with open(args.output_path, "w") as f:
        for context, answer, question, generated_question in generated_results:
            f.write(f"문맥\t{context}\n")
            f.write(f"답변\t{answer}\n")
            f.write(f"생성된 질문\t{generated_question}\n")
            f.write(f"실제 질문\t{question}\n\n")
 def __init__(self, path, max_tokens):
     self.tokenizer = SentencePieceBPETokenizer.from_file(
         str(path / "vocab.json"),
         str(path / "merges.txt"),
     )
     self.max_tokens = max_tokens
     self.idx = {}
     for s in ['</s>', '<s>', '<pad>']:
         self.idx[s] = self.tokenizer.token_to_id(s)
Example #19
class DecodeBySentencePieceBPETokenizer(Preprocessor):
    __provider__ = 'decode_by_sentence_piece_bpe_tokenizer'

    @classmethod
    def parameters(cls):
        parameters = super().parameters()
        parameters.update({
            'vocabulary_file':
            PathField(),
            'merges_file':
            PathField(),
            'sos_symbol':
            StringField(optional=True, default='<s>'),
            'eos_symbol':
            StringField(optional=True, default='</s>'),
            'add_extra_symbols':
            BoolField(optional=True, default=True),
        })

        return parameters

    def configure(self):
        if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
            SentencePieceBPETokenizer.raise_error(self.__provider__)
        self.tokenizer = SentencePieceBPETokenizer(
            str(self.get_value_from_config('vocabulary_file')),
            str(self.get_value_from_config('merges_file')))
        self.add_extra_symbols = self.get_value_from_config(
            'add_extra_symbols')
        self.idx = {}
        for s in ['sos', 'eos']:
            self.idx[s] = self.tokenizer.token_to_id(
                str(self.get_value_from_config(s + '_symbol')))

    def process(self, image, annotation_meta=None):
        sentence = " ".join(image.data)
        tokens = self.tokenizer.encode(sentence).ids
        if self.add_extra_symbols:
            tokens = [self.idx['sos']] + tokens + [self.idx['eos']]
        image.data = tokens
        image.metadata['decoded'] = True
        image.identifier = "tokens"

        return image
def main():
    #argparser
    parser = argparse.ArgumentParser(
        prog="train_mlm_camembert_thai.py",
        description="train mlm for Camembert with huggingface Trainer",
    )

    #required
    parser.add_argument("--bpe_tokenizer",
                        type=str,
                        default='sentencepiece',
                        help='Specify the name of BPE Tokenizer')
    parser.add_argument("--vocab_size", type=int, default=52000)
    parser.add_argument("--min_frequency", type=int, default=2)
    parser.add_argument(
        "--train_dir",
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
    )
    parser.add_argument("--ext", type=str, default='.txt')

    args = parser.parse_args()

    fnames = [str(x) for x in glob.glob(f"{args.train_dir}/*{args.ext}")]

    # Initialize a tokenizer
    if args.bpe_tokenizer == 'byte_level':
        tokenizer = ByteLevelBPETokenizer()
    elif args.bpe_tokenizer == 'char':
        tokenizer = CharBPETokenizer()
    elif args.bpe_tokenizer == 'sentencepiece':
        tokenizer = SentencePieceBPETokenizer()
    else:
        raise ValueError(f"Unknown BPE tokenizer: {args.bpe_tokenizer}")

    # Customize training
    tokenizer.train(files=fnames,
                    vocab_size=args.vocab_size,
                    min_frequency=args.min_frequency,
                    special_tokens=[
                        "<s>",
                        "<pad>",
                        "</s>",
                        "<unk>",
                        "<mask>",
                    ])

    # Save files to disk
    tokenizer.save_model(args.output_dir)

    #test
    tokenizer = CamembertTokenizer.from_pretrained(args.output_dir)
    print(tokenizer.encode_plus('สวัสดีครับ hello world'))
def get_default_tokenizer():
    from tokenizers import SentencePieceBPETokenizer

    tokenizer = SentencePieceBPETokenizer(path.join(VOCAB_PATH,
                                                    'en-vocab.json'),
                                          path.join(VOCAB_PATH,
                                                    'en-merges.txt'),
                                          unk_token='[UNK]')

    return tokenizer
    def __init__(self, max_meta_len: int, max_body_len: int,
                 ignore_meta_prob: float):
        tokenizer = SentencePieceBPETokenizer(vocab_file=str(_VOCAB),
                                              merges_file=str(_MERGES))

        super().__init__(tokenizer=tokenizer,
                         max_meta_len=max_meta_len,
                         max_body_len=max_body_len,
                         ignore_meta_prob=ignore_meta_prob,
                         pad_token='<pad>')
Example #23
 def load_tokenizer(path,
                    enable_truncation=True,
                    enable_padding=True,
                    max_length=512):
     tokenizer = SentencePieceBPETokenizer(os.path.join(path, "vocab.json"),
                                           os.path.join(path, "merges.txt"))
     tokenizer._tokenizer.post_processor = BertProcessing(
         ("</s>", tokenizer.token_to_id("</s>")),
         ("<s>", tokenizer.token_to_id("<s>")),
     )
     if enable_truncation:
         tokenizer.enable_truncation(max_length=max_length)
     if enable_padding:
         tokenizer.enable_padding(pad_token="<pad>",
                                  pad_id=tokenizer.token_to_id("<pad>"))
     return tokenizer
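A usage sketch for load_tokenizer (the directory is illustrative and assumes a previously trained vocab.json / merges.txt pair); after BertProcessing, single-sentence encodings come out wrapped as <s> ... </s>.

# Usage sketch (illustrative directory containing a trained vocab.json / merges.txt):
tok = load_tokenizer("./my_tokenizer", max_length=128)
enc = tok.encode("A first sentence")
print(enc.tokens)  # e.g. ['<s>', '▁A', '▁first', '▁sentence', '</s>']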
Example #24
class NonAutoregressiveMachineTranslationAdapter(Adapter):
    __provider__ = 'narnmt'

    @classmethod
    def parameters(cls):
        parameters = super().parameters()
        parameters.update({
            'vocabulary_file':
            PathField(),
            'merges_file':
            PathField(),
            'output_name':
            StringField(optional=True, default=None),
            'sos_symbol':
            StringField(optional=True, default='<s>'),
            'eos_symbol':
            StringField(optional=True, default='</s>'),
            'pad_symbol':
            StringField(optional=True, default='<pad>'),
            'remove_extra_symbols':
            BoolField(optional=True, default=True)
        })
        return parameters

    def configure(self):
        if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
            SentencePieceBPETokenizer.raise_error(self.__provider__)
        self.tokenizer = SentencePieceBPETokenizer(
            str(self.get_value_from_config('vocabulary_file')),
            str(self.get_value_from_config('merges_file')))
        self.remove_extra_symbols = self.get_value_from_config(
            'remove_extra_symbols')
        self.idx = {}
        for s in ['sos', 'eos', 'pad']:
            self.idx[s] = str(self.get_value_from_config(s + '_symbol'))
        self.output_name = self.get_value_from_config('output_name')
        if self.output_name is None:
            self.output_name = self.output_blob

    def process(self, raw, identifiers, frame_meta):
        raw_outputs = self._extract_predictions(raw, frame_meta)
        translation = raw_outputs[self.output_name]
        results = []
        for identifier, tokens in zip(identifiers, translation):
            sentence = self.tokenizer.decode(tokens)
            if self.remove_extra_symbols:
                for s in self.idx.values():
                    sentence = sentence.replace(s, '')
            results.append(
                MachineTranslationPrediction(identifier,
                                             sentence.lstrip().split(' ')))
        return results
def load_tokenizer(langpair: str) -> SentencePieceBPETokenizer:
    if langpair in ["en-de", "de-en", "ende", "deen", "ENDE", "EN-DE"]:
        langpair = "deen"

    tokenizer_dir = Path(__file__).parent.parent / "src" / "tokenizer"
    vocab_filepath = (
        tokenizer_dir / f"sentencepiece_bpe_wmt14_{langpair}.tokenizer-vocab.json"
    )
    merges_filepath = (
        tokenizer_dir / f"sentencepiece_bpe_wmt14_{langpair}.tokenizer-merges.txt"
    )

    tokenizer = SentencePieceBPETokenizer(
        vocab_file=str(vocab_filepath),
        merges_file=str(merges_filepath),
    )
    return tokenizer
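A brief usage sketch (output tokens depend on the trained WMT14 vocabulary): both directions of the language pair resolve to the same "deen" tokenizer files.

# Usage sketch: "en-de" and "de-en" both resolve to the "deen" tokenizer files.
tokenizer = load_tokenizer("en-de")
print(tokenizer.encode("a small test sentence").tokens)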
Example #26
 def train(corpus_list, vocab_size, output, output_name=None):
     print("create tokenizer...")
     tokenizer = SentencePieceBPETokenizer()
     print("load corpus list...")
     corpus_list = open(corpus_list).read().split('\n')[:-1]
     print("train tokenizer...")
     tokenizer.train(
         corpus_list,
         vocab_size=vocab_size,
         special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"])
     print("save model...")
     tokenizer.save_model(output, output_name)
Example #27
def main():
    args = cmd_args()
    outdir = args.o if args.o else os.path.dirname(args.i)

    print(
        f"Training SentencePiece to create a vocabulary of size {args.vocab_size}"
    )
    with tempfile.TemporaryDirectory() as tmp_dir:
        train_file = os.path.join(tmp_dir, "train.txt")
        create_bpe_training_file(args.i, train_file)

        tokenizer = SentencePieceBPETokenizer()
        tokenizer.train(files=[train_file], vocab_size=args.vocab_size)

    tokenizer.save(outdir, args.n)
Example #28
def build_bpe(vocab_size=10000):
    # Initialize a tokenizer
    tokenizer = SentencePieceBPETokenizer()

    #mypath = "../../Downloads/riksdagens_protokoll_1920-2020/annual"
    mypath = "../../Desktop/cood/python/machine-learning/old-school/markov-lstm-killer/data/fi"
    onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
    print("ONL", onlyfiles)

    paths = [mypath + "/" + f for f in onlyfiles]

    #paths = paths[:5]

    # COPY FILES
    txts = []
    for path, fname in zip(paths, onlyfiles):
        if path[-4:] == ".txt":
            localpath = "data/" + fname
            txts.append(localpath)

            infile = open(path)
            outfile = open(localpath, "w")

            for line in infile:
                clean_line = cleanup(line) + "\n"
                outfile.write(clean_line)

            outfile.close()

    # Then train it!
    #tokenizer.train([ "../../Downloads/riksdagens_protokoll_1920-2020/annual/prot_2019.txt" ], vocab_size=15000)
    tokenizer.train(txts, vocab_size=vocab_size)

    # Now, let's use it:
    s = "Det politiska arbetet har redan börjat på olika sätt, med resor, besök, möten, politikutveckling, motionsskrivande och mycket annat. Jag har sett att ni redan har varit aktiva under ett antal veckor, och jag kan försäkra er att det även gäller talmanspresidiet. Nu är det dags att med tillförsikt påbörja ett nytt riksdagsår. Jag hoppas att ni alla ser fram emot det lika myck­et som jag gör."
    #s = "Ite en oo viel mitää hyvää kyl sielt syöny."
    #s = "ja kieltämät siihe tommoste kokonaisii sanoi merkitsevät tavumerkit on huomattavasti näppärämpii ku ääniä tarkottavat aakkoset joist pitää rakentaa jokane sana"
    encoded = tokenizer.encode(s)

    print(encoded.ids)
    print(encoded.tokens)
    # And finally save it somewhere
    tokenizer.save("./bpe-fi.tokenizer.json")
 def __init__(self,
              vocab_file,
              merges_file,
              unk_token="<unk>",
              bos_token="<s>",
              eos_token="</s>",
              pad_token="<pad>",
              add_prefix_space=False,
              **kwargs):
     super().__init__(
         SentencePieceBPETokenizer(
             vocab_file=vocab_file,
             merges_file=merges_file,
             add_prefix_space=add_prefix_space,
         ),
         bos_token=bos_token,
         eos_token=eos_token,
         unk_token=unk_token,
         pad_token=pad_token,
         **kwargs,
     )
class TokenizerWrapper:
    def __init__(self, tok_type, unk_token, sep_token, cls_token, pad_token,
                 mask_token):
        self.tok_type = tok_type

        if self.tok_type == 'bpe':
            self.tokenizer = ByteLevelBPETokenizer()
        elif self.tok_type == 'wordpiece':
            self.tokenizer = BertWordPieceTokenizer(unk_token=unk_token,
                                                    sep_token=sep_token,
                                                    cls_token=cls_token,
                                                    pad_token=pad_token,
                                                    mask_token=mask_token)
        elif self.tok_type == 'sentencepiece':
            self.tokenizer = SentencePieceBPETokenizer(unk_token=unk_token)

    def train(self, data_file, vocab_size, special_tokens):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            self.tokenizer.train([data_file],
                                 vocab_size=vocab_size,
                                 special_tokens=special_tokens)

    def tokenize(self, text):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            return self.tokenizer.encode(text).tokens
        elif self.tok_type == 'word':
            return nltk.tokenize.word_tokenize(text)
        elif self.tok_type == 'char':
            return [c for c in text]
        else:
            raise Exception('Unknown tokenizer: ' + self.tok_type)

    def decode(self, tokens, blank_token):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            ids = [self.tokenizer.token_to_id(t) for t in tokens]
            ids = [
                i if i is not None else self.tokenizer.token_to_id(blank_token)
                for i in ids
            ]
            return self.tokenizer.decode(ids, skip_special_tokens=False)
        elif self.tok_type == 'word':
            return ' '.join(tokens)
        elif self.tok_type == 'char':
            return ''.join(tokens)
        else:
            raise Exception('Unknown tokenizer: ' + self.tok_type)
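A usage sketch for TokenizerWrapper (corpus path, vocabulary size, and special tokens are illustrative):

# Usage sketch (illustrative corpus path and settings):
wrapper = TokenizerWrapper('sentencepiece', unk_token='<unk>', sep_token='<sep>',
                           cls_token='<cls>', pad_token='<pad>', mask_token='<mask>')
wrapper.train('corpus.txt', vocab_size=8000,
              special_tokens=['<unk>', '<sep>', '<cls>', '<pad>', '<mask>'])
tokens = wrapper.tokenize('a small test sentence')
print(wrapper.decode(tokens, blank_token='<unk>'))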