def main(train_path, val_path, test_path, config_path, subword_model_path,
         out_dir):
    """Encode train/val/test splits into tagged SentencePiece subword files.

    The dataset reader is built from the ``"reader"`` section of the config
    file at ``config_path``. For every ``(text, summary)`` pair that
    ``reader.parse_set`` yields for a split, both sides are encoded into
    subword pieces, wrapped in ``<t>`` / ``</t>`` tags and written as one
    space-separated line to ``<split>.text.txt`` / ``<split>.summary.txt``
    inside ``out_dir`` (the directory is not created here).

    :param train_path: training split, passed to ``reader.parse_set``.
    :param val_path: validation split, passed to ``reader.parse_set``.
    :param test_path: test split, passed to ``reader.parse_set``.
    :param config_path: config file read with ``Params.from_file``.
    :param subword_model_path: SentencePiece model file to load.
    :param out_dir: existing directory receiving the six output files.
    """
    params = Params.from_file(config_path)
    reader_params = params.pop("reader", default=Params({}))
    reader = DatasetReader.from_params(reader_params)

    processor = SentencePieceProcessor()
    processor.Load(subword_model_path)

    splits = (("train", train_path), ("val", val_path), ("test", test_path))
    for split, path in splits:
        text_file_name = os.path.join(out_dir, "{}.text.txt".format(split))
        summary_file_name = os.path.join(out_dir,
                                         "{}.summary.txt".format(split))
        # Subword pieces carry the non-ASCII "\u2581" word marker, so write
        # UTF-8 explicitly instead of relying on the platform default.
        with open(text_file_name, "w", encoding="utf-8") as text_file, \
                open(summary_file_name, "w", encoding="utf-8") as summary_file:
            for text, summary in reader.parse_set(path):
                text_subwords = ["<t>"] + processor.EncodeAsPieces(text) + ["</t>"]
                summary_subwords = (["<t>"] + processor.EncodeAsPieces(summary)
                                    + ["</t>"])
                text_file.write(" ".join(text_subwords) + "\n")
                summary_file.write(" ".join(summary_subwords) + "\n")
def main(train_path,
         val_path,
         test_path,
         config_path,
         subword_model_path,
         out_dir,
         max_text_subwords,
         max_summary_subwords,
         source_suffix,
         target_suffix,
         insert_tags=False,
         lowercase=False):
    """Encode train/val/test splits into SentencePiece subword files.

    The dataset reader is built from the ``"dataset_reader"`` section of the
    config file. For each ``(text, summary)`` pair yielded by
    ``reader.parse_set``, both sides are optionally lower-cased, encoded into
    subword pieces, truncated, optionally wrapped in ``<t>`` / ``</t>`` and
    written as one space-separated line to ``<split>.<source_suffix>`` /
    ``<split>.<target_suffix>`` inside ``out_dir``.

    :param max_text_subwords: keep at most this many text pieces; a falsy
        value (``None``/0) disables truncation.
    :param max_summary_subwords: same as above for the summary.
    :param source_suffix: file extension for the text side.
    :param target_suffix: file extension for the summary side.
    :param insert_tags: wrap each line in ``<t>`` ... ``</t>`` (added after
        truncation, so tags are never cut off).
    :param lowercase: lower-case the raw strings before encoding.
    """
    params = Params.from_file(config_path)
    reader_params = params.pop("dataset_reader", default=Params({}))
    reader = DatasetReader.from_params(reader_params)

    processor = SentencePieceProcessor()
    processor.Load(subword_model_path)

    def encode(sentence, max_subwords):
        """Return the space-joined (optionally truncated/tagged) pieces of sentence."""
        if lowercase:
            sentence = sentence.lower()
        subwords = processor.EncodeAsPieces(sentence)
        if max_subwords:
            subwords = subwords[:max_subwords]
        if insert_tags:
            subwords = ["<t>"] + subwords + ["</t>"]
        return " ".join(subwords)

    splits = (("train", train_path), ("val", val_path), ("test", test_path))
    for split, path in splits:
        text_file_name = os.path.join(out_dir,
                                      "{}.{}".format(split, source_suffix))
        summary_file_name = os.path.join(out_dir,
                                         "{}.{}".format(split, target_suffix))
        # Pieces contain the non-ASCII "\u2581" marker: write UTF-8 explicitly.
        with open(text_file_name, "w", encoding="utf-8") as text_file, \
                open(summary_file_name, "w", encoding="utf-8") as summary_file:
            for text, summary in reader.parse_set(path):
                text_file.write(encode(text, max_text_subwords) + "\n")
                summary_file.write(
                    encode(summary, max_summary_subwords) + "\n")
Exemple #3
0
def main(train_path,
         val_path,
         test_path,
         mode,
         subword_model_path,
         output_dir,
         max_source_subwords,
         max_target_subwords,
         source_suffix,
         target_suffix,
         lowercase=False):
    """Encode train/val/test splits of one task into SentencePiece subword files.

    ``MODES[mode]`` is used as the parser; each record it yields must have
    ``"source"`` and ``"target"`` strings. Both are optionally lower-cased,
    encoded into pieces, truncated, and written as one space-separated line
    to ``<split>.<source_suffix>`` / ``<split>.<target_suffix>`` inside
    ``output_dir`` (created if missing).

    :param mode: key into ``MODES`` selecting the record parser.
    :param max_source_subwords: keep at most this many source pieces; a falsy
        value (``None``/0) disables truncation.
    :param max_target_subwords: same as above for the target side.
    :param lowercase: lower-case the raw strings before encoding.
    :raises ValueError: if ``mode`` is not a key of ``MODES``.
    """
    # Validate before touching the filesystem; an assert would be stripped
    # under ``python -O``.
    parse = MODES.get(mode, None)
    if parse is None:
        raise ValueError("Unknown mode {!r}; expected one of {}".format(
            mode, list(MODES)))

    processor = SentencePieceProcessor()
    processor.Load(subword_model_path)

    def encode(sentence, max_subwords):
        """Return the space-joined (optionally truncated) pieces of sentence."""
        if lowercase:
            sentence = sentence.lower()
        subwords = processor.EncodeAsPieces(sentence)
        if max_subwords:
            subwords = subwords[:max_subwords]
        return " ".join(subwords)

    os.makedirs(output_dir, exist_ok=True)
    splits = (("train", train_path), ("val", val_path), ("test", test_path))
    for split, path in splits:
        source_file_name = os.path.join(output_dir,
                                        "{}.{}".format(split, source_suffix))
        target_file_name = os.path.join(output_dir,
                                        "{}.{}".format(split, target_suffix))
        # Pieces contain the non-ASCII "\u2581" marker: write UTF-8 explicitly.
        with open(source_file_name, "w", encoding="utf-8") as source_file, \
                open(target_file_name, "w", encoding="utf-8") as target_file:
            for record in parse(path):
                source_file.write(
                    encode(record["source"], max_source_subwords) + "\n")
                target_file.write(
                    encode(record["target"], max_target_subwords) + "\n")
def sentence2vector(sentence, model_sentence_piece: spm.SentencePieceProcessor,
                    dict_token2vector: dict,
                    dict_is_valid: dict) -> [np.array, bool]:
    """
    Convert a sentence into one vector by averaging its token vectors.

    The second element of the returned list tells whether any token could be
    given a meaning (vector):
    - True: at least one valid token contributed to the vector
    - False: there was none; the vector is all zeros (the origin)
    :param sentence: str, the target sentence
    :param model_sentence_piece: spm.SentencePieceProcessor, the sentencepiece model
    :param dict_token2vector: dict, maps a token to its extended word embedding
    :param dict_is_valid: dict, maps a token to True when the model contains it
    :return: [np.ndarray of shape (vector_size,), bool]
    :raises ValueError: if dict_token2vector is empty (vector size unknown)
    """
    tokens_raw = model_sentence_piece.EncodeAsPieces(sentence)
    # leave_valid is defined elsewhere; presumably it keeps only the tokens
    # that dict_is_valid marks as valid — confirm.
    tokens = leave_valid(tokens=tokens_raw, dict_is_valid=dict_is_valid)

    if not dict_token2vector:
        raise ValueError("dict_token2vector must not be empty")
    # Read the embedding size from one vector without copying all values.
    vector_size = len(next(iter(dict_token2vector.values())))

    # With no meaningful token, the result should collapse to the origin,
    # hence the zero initialisation (and at least one column to average).
    vector = np.zeros((vector_size, max(len(tokens), 1)), dtype=np.float64)
    for i, token in enumerate(tokens):
        vector[:, i] = dict_token2vector[token]

    has_valid_token = len(tokens) != 0
    if not has_valid_token:
        # The mean is correct either way; this branch only reports the problem.
        # It is still open whether this should raise instead of printing.
        # SentencePiece marks word boundaries with U+2581 ("\u2581"), not "_";
        # strip that marker so the preview reads like the original text.
        preview = ''.join(tokens_raw[:20]).replace('\u2581', '')
        print(
            f" this sentence has no valid token, change phrase or word\n{preview}"
        )
    return [vector.mean(axis=1), has_valid_token]
Exemple #5
0
class SentencePieceTokenizer():#TODO: pass the special tokens symbol to sp
    "SentencePiece tokenizer for `lang`"
    def __init__(self, lang='en', special_toks=None, sp_model=None, vocab_sz=None, max_vocab_sz=30000,
                 model_type='unigram', char_coverage=None, cache_dir='tmp'):
        try: from sentencepiece import SentencePieceTrainer,SentencePieceProcessor
        except ImportError:
            raise Exception('sentencepiece module is missing: run `pip install sentencepiece!=0.1.90,!=0.1.91`')
        self.sp_model,self.cache_dir = sp_model,Path(cache_dir)
        self.vocab_sz,self.max_vocab_sz,self.model_type = vocab_sz,max_vocab_sz,model_type
        # Slightly higher coverage for languages listed in `eu_langs`.
        self.char_coverage = ifnone(char_coverage, 0.99999 if lang in eu_langs else 0.9998)
        self.special_toks = ifnone(special_toks, defaults.text_spec_tok)
        if sp_model is None: self.tok = None
        else:
            self.tok = SentencePieceProcessor()
            self.tok.Load(str(sp_model))
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_vocab_sz(self, raw_text_path):
        "Estimate a vocab size: 1/4 of the distinct whitespace words, rounded up to a multiple of 8, at least 29"
        cnt = Counter()
        # The corpus is written as UTF-8 in `setup`; read it back the same way
        # and stream it so the early exit below avoids loading the whole file.
        with open(raw_text_path, 'r', encoding='utf-8') as f:
            for line in f:
                cnt.update(line.split())
                # Stop counting as soon as the estimate exceeds the cap.
                if len(cnt)//4 > self.max_vocab_sz: return self.max_vocab_sz
        res = len(cnt)//4
        while res%8 != 0: res+=1
        return max(res,29)

    def train(self, raw_text_path):
        "Train a sentencepiece tokenizer on `texts` and save it in `path/tmp_dir`"
        from sentencepiece import SentencePieceTrainer
        vocab_sz = self._get_vocab_sz(raw_text_path) if self.vocab_sz is None else self.vocab_sz
        # Special tokens are registered with a leading U+2581 word marker.
        spec_tokens = ['\u2581'+s for s in self.special_toks]
        SentencePieceTrainer.Train(" ".join([
            f"--input={raw_text_path} --vocab_size={vocab_sz} --model_prefix={self.cache_dir/'spm'}",
            f"--character_coverage={self.char_coverage} --model_type={self.model_type}",
            f"--unk_id={len(spec_tokens)} --pad_id=-1 --bos_id=-1 --eos_id=-1 --minloglevel=2",
            f"--user_defined_symbols={','.join(spec_tokens)} --hard_vocab_limit=false"]))
        # The raw corpus is only needed for training.
        raw_text_path.unlink()
        return self.cache_dir/'spm.model'

    def setup(self, items, rules=None):
        "Train a model on `items` (after applying `rules`) unless one is already loaded; return `{'sp_model': path}`"
        from sentencepiece import SentencePieceProcessor
        if rules is None: rules = []
        if self.tok is not None: return {'sp_model': self.sp_model}
        raw_text_path = self.cache_dir/'texts.out'
        # SentencePiece expects UTF-8 input; don't depend on the locale default.
        with open(raw_text_path, 'w', encoding='utf-8') as f:
            for t in progress_bar(maps(*rules, items), total=len(items), leave=False):
                f.write(f'{t}\n')
        sp_model = self.train(raw_text_path)
        self.tok = SentencePieceProcessor()
        self.tok.Load(str(sp_model))
        return {'sp_model': sp_model}

    def __call__(self, items):
        "Lazily yield the SentencePiece pieces of each item, training a model first if needed"
        if self.tok is None: self.setup(items)
        for t in items: yield self.tok.EncodeAsPieces(t)
class SentencepieceTokenizer(BaseTokenizer):
    """Tokenizer that splits text into SentencePiece subword pieces."""

    def __init__(self, model_path: str) -> None:
        """Load the SentencePiece model stored at ``model_path``."""
        from sentencepiece import SentencePieceProcessor
        super().__init__(name="sentencepiece")
        self._tokenizer = SentencePieceProcessor()
        self._tokenizer.load(model_path)

    def tokenize(self, text: str) -> List[Token]:
        """Return one ``Token`` (with ``surface`` set) per piece of ``text``."""
        pieces = self._tokenizer.EncodeAsPieces(text)
        return [Token(surface=piece) for piece in pieces]
class SubwordTokenizer(Tokenizer):
    """SentencePiece tokenizer with optional subword-regularization sampling.

    Sampling (``SampleEncodeAsPieces``) is used only when both ``nbest_size``
    and ``alpha`` are truthy; otherwise the deterministic
    ``EncodeAsPieces`` segmentation is returned.
    """

    def __init__(self,
                 model_path: str = None,
                 nbest_size: int = None,
                 alpha: float = None):
        self._model_path = cached_path(model_path)
        self._processor = SentencePieceProcessor()
        self._processor.Load(self._model_path)
        self._nbest_size = nbest_size
        self._alpha = alpha

    def tokenize(self, text: str) -> List[Token]:
        """Split ``text`` into subword ``Token`` objects."""
        sampling_enabled = self._nbest_size and self._alpha
        if not sampling_enabled:
            pieces = self._processor.EncodeAsPieces(text)
        else:
            pieces = self._processor.SampleEncodeAsPieces(
                text, self._nbest_size, self._alpha)
        return [Token(piece) for piece in pieces]

    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
        """Tokenize each text independently, preserving order."""
        return list(map(self.tokenize, texts))
class SentencepieceFasttextEmbed(EmbedderInterface):
    """Embedder that tokenizes with SentencePiece and embeds pieces with fastText.

    Ids/pieces come from a SentencePiece model; piece embeddings come from a
    fastText model that is only loaded when a model file is given.
    """

    class Config(EmbedderInterface.Config):
        pass

    @classmethod
    def from_config(cls, config: Config):
        """Build from ``spm.model`` / ``fasttext-model.bin`` in ``config.preproc_dir``."""
        spm_model_file = os.path.join(config.preproc_dir, "spm.model")
        fasttext_model_file = os.path.join(config.preproc_dir,
                                           "fasttext-model.bin")
        return cls(spm_model_file, fasttext_model_file, config.max_pieces)

    def __init__(self,
                 spm_model_file: str,
                 fasttext_model_file: str = '',
                 max_pieces: int = -1):
        """
        :param spm_model_file: path of the SentencePiece model to load.
        :param fasttext_model_file: path of a fastText model; when empty no
            ``self.fasttext`` is set, so ``embed_dim`` and the ``embed_*``
            methods raise AttributeError.
        :param max_pieces: forwarded to ``EmbedderInterface.__init__``
            (presumably used to build ``self.pieces_slice`` — confirm).
        """
        super().__init__(max_pieces=max_pieces)

        self.spm = SentencePieceProcessor()
        self.spm.Load(spm_model_file)
        # Ids of the special symbols and their piece strings.
        self.pad_idx = self.spm.pad_id()
        self.pad_token = self.spm.IdToPiece(self.pad_idx)
        self.unk_idx = self.spm.unk_id()
        self.unk_token = self.spm.IdToPiece(self.unk_idx)
        self.bos_idx = self.spm.bos_id()
        self.bos_token = self.spm.IdToPiece(self.bos_idx)
        self.eos_idx = self.spm.eos_id()
        self.eos_token = self.spm.IdToPiece(self.eos_idx)

        if fasttext_model_file:
            self.fasttext = fasttext.load_model(fasttext_model_file)

    @property
    def embed_dim(self):
        """Dimension of the fastText embeddings."""
        return self.fasttext.dim

    @property
    def n_vocab(self):
        """Number of pieces in the SentencePiece vocabulary."""
        return self.spm.get_piece_size()

    def encode_text_as_ids(self, text: str) -> np.array:
        """Encode ``text`` to int32 ids (``pieces_slice`` applied). Doesn't produce BOS, EOS ids."""
        return np.asarray(self.spm.EncodeAsIds(text)[self.pieces_slice],
                          dtype=np.int32)

    def encode_text_as_tokens(self, text: str) -> List[str]:
        """Encode ``text`` to pieces (``pieces_slice`` applied). Doesn't produce BOS, EOS tokens."""
        return self.spm.EncodeAsPieces(text)[self.pieces_slice]

    def tokenize(self, text: str) -> List[str]:
        """Alias for `encode_text_as_tokens`. Doesn't produce BOS, EOS tokens."""
        # encode_text_as_tokens already applies pieces_slice; slicing a second
        # time is redundant for prefix slices and drops extra pieces otherwise.
        return self.encode_text_as_tokens(text)

    def decode_ids_as_text(self, ids: List[int], strip_special=True) -> str:
        """Decode ids to text.

        With ``strip_special`` (default), PAD, BOS and EOS ids are removed
        before decoding. UNK is decoded but unintelligible.
        """
        if strip_special:
            special = (self.pad_idx, self.bos_idx, self.eos_idx)
            ids = [int(ix) for ix in ids if ix not in special]
        else:
            ids = [int(ix) for ix in ids]
        return self.spm.DecodeIds(ids)

    def decode_tokens_as_text(self, toks: List[str]) -> str:
        """Decode pieces to text after applying ``pieces_slice``.

        NOTE(review): unlike ``decode_ids_as_text``, PAD/BOS/EOS pieces are not
        filtered explicitly here; presumably ``DecodePieces`` drops control
        symbols itself — confirm. UNK is decoded but unintelligible.
        """
        return self.spm.DecodePieces(toks[self.pieces_slice])

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps instances
    # alive for the lifetime of the (class-level) cache.
    @functools.lru_cache(maxsize=1024)
    def decode_id_as_token(self, id: int) -> str:
        """Return the piece string for one id (memoized)."""
        return self.spm.IdToPiece(id)

    def decode_ids_as_tokens(self,
                             ids: List[int],
                             strip_special: bool = True) -> List[str]:
        """Map ids to pieces; by default PAD, BOS, EOS ids are dropped.

        Avoids a problematic intermediate string representation that causes a
        length mismatch: SentencePiece isn't isomorphic with respect to the
        string representation.
        """
        if strip_special:
            special = (self.pad_idx, self.bos_idx, self.eos_idx)
            ids = [ix for ix in ids if ix not in special]
        return [self.decode_id_as_token(int(ix)) for ix in ids]

    # NOTE(review): cached arrays are shared between callers; mutating a
    # returned embedding in place would corrupt the cache.
    @functools.lru_cache(maxsize=1024)
    def embed_tok(self, tok: str) -> np.array:
        """Return the fastText vector of ``tok``; all zeros for the PAD token."""
        if tok == self.pad_token:
            return np.zeros(self.fasttext.dim)
        return np.asarray(self.fasttext[tok])

    def embed_text(self, text: str) -> np.array:
        """Embed every piece of ``text``, stacked into one array.

        PAD, BOS and EOS are not produced by tokenization, so they never appear.
        """
        return np.asarray([self.embed_tok(tok) for tok in self.tokenize(text)])

    def embed_ids(self,
                  ids: List[int],
                  strip_special: bool = True) -> List[np.array]:
        """Embed ids via their pieces; by default PAD, BOS, EOS ids are dropped.

        Going through ``decode_ids_as_tokens`` avoids the length mismatch that
        an intermediate string representation would cause.
        """
        return [
            self.embed_tok(t)
            for t in self.decode_ids_as_tokens(ids,
                                               strip_special=strip_special)
        ]

    def embed_ids_batch(self, ids: np.array) -> torch.tensor:
        """Embed each row of ``ids`` (special ids kept) into one tensor.

        Rows presumably must have equal length for ``torch.tensor`` — confirm.
        """
        emb = [self.embed_ids(turn, strip_special=False) for turn in ids]
        emb = torch.tensor(emb)
        return emb
Exemple #9
0
def main(input_dir,
         subword_model_path,
         output_dir,
         max_source_subwords,
         max_target_subwords,
         source_suffix,
         target_suffix,
         lowercase=False):
    """Merge all task directories of ``input_dir`` into shuffled subword files.

    Every entry of ``input_dir`` not starting with ``_`` is a task directory
    whose lower-cased name must be a key of ``MODES`` (the record parser).
    For each split, ``train.jsonl`` / ``val.jsonl`` / ``test.jsonl`` of every
    task are parsed; the ``lidirus`` task only contributes to the test split,
    read from ``LiDiRuS.jsonl``. Each source is prefixed with
    ``mode SEPARATOR idx SEPARATOR``, optionally lower-cased, encoded into
    pieces and truncated. The records of a split are shuffled with the
    module-level ``random`` generator and written one per line to
    ``<split>.<source_suffix>`` / ``<split>.<target_suffix>`` in
    ``output_dir`` (created if missing).

    :raises ValueError: if a task directory has no entry in ``MODES``.
    """
    processor = SentencePieceProcessor()
    processor.Load(subword_model_path)

    os.makedirs(output_dir, exist_ok=True)

    def encode(sentence, max_subwords):
        """Return the space-joined (optionally truncated) pieces of sentence."""
        if lowercase:
            sentence = sentence.lower()
        subwords = processor.EncodeAsPieces(sentence)
        if max_subwords:
            subwords = subwords[:max_subwords]
        return " ".join(subwords)

    # os.listdir order is arbitrary; sort it so that, for a given random
    # seed, the shuffled output is reproducible across machines.
    tasks = []
    for d in sorted(os.listdir(input_dir)):
        if d.startswith("_"):
            continue
        mode = d.lower()
        parse = MODES.get(mode, None)
        if parse is None:
            raise ValueError("Unknown task directory {!r}; expected one of {}"
                             .format(d, list(MODES)))
        tasks.append((os.path.join(input_dir, d), mode, parse))

    splits = (("train", "train.jsonl"), ("val", "val.jsonl"),
              ("test", "test.jsonl"))
    for split, orig_file_name in splits:
        records = []
        for d, mode, parse in tasks:
            if mode == "lidirus":
                # LiDiRuS is a diagnostic set: test split only, own file name.
                if orig_file_name != "test.jsonl":
                    continue
                path = os.path.join(d, "LiDiRuS.jsonl")
            else:
                path = os.path.join(d, orig_file_name)
            for record in parse(path):
                source = mode + SEPARATOR + str(
                    record["idx"]) + SEPARATOR + record["source"]
                records.append((encode(source, max_source_subwords),
                                encode(record["target"], max_target_subwords)))
        random.shuffle(records)

        source_file_name = os.path.join(output_dir,
                                        "{}.{}".format(split, source_suffix))
        target_file_name = os.path.join(output_dir,
                                        "{}.{}".format(split, target_suffix))
        # Pieces contain non-ASCII characters: write UTF-8 explicitly.
        with open(source_file_name, "w", encoding="utf-8") as source_file, \
                open(target_file_name, "w", encoding="utf-8") as target_file:
            for source, target in records:
                source_file.write(source + "\n")
                target_file.write(target + "\n")
Exemple #10
0
def main():
    """Train a SentencePiece model on a TSV corpus and export BERT/GPT-2 vocabularies.

    Reads the first column of ``options.input`` (tab-separated, blank rows
    skipped), copies it to ``<out_dir>/<basename>.tsv`` and trains a
    SentencePiece model with 32000 pieces (``<out_dir>/spiece.model`` /
    ``spiece.vocab``) on it. It then writes:

    - ``vocab.txt``: BERT-style vocabulary ([PAD], [UNK], [CLS], [SEP],
      [MASK], then the pieces, padded with ``unusedN`` entries);
    - ``vocab.json``: GPT-2 style piece -> id map ("\u2581" prefix mapped to
      "Ġ", ``<|endoftext|>`` appended);
    - ``merges.txt``: de-duplicated BPE merges learned with ``learn_bpe``
      on the corpus encoded into SentencePiece pieces.
    """
    options = parse_args()
    torch.manual_seed(options.seed)
    basename = os.path.splitext(os.path.basename(options.input))[0]
    out_dir = options.out_dir or "data/{}/".format(basename)
    spinner = Halo(spinner="dots", placement="right")

    with open(options.input, "r", encoding="utf8") as fd:
        reader = csv.reader(fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar="")
        # Blank rows come back as [] and have no first column; skip them.
        lines = [[line[0]] for line in reader if line]

    os.makedirs(out_dir, exist_ok=True)
    output_full = os.path.join(out_dir, "{}.tsv".format(basename))
    with open(output_full, "w", encoding="utf8") as fd:
        writer = csv.writer(fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar="")
        writer.writerows(lines)

    vocab_size = 32000
    spiece_out = os.path.join(out_dir, "spiece")
    spiece_args = (
        "--input={} "
        "--model_prefix={} "
        "--vocab_size={} "
        "--character_coverage=1.0"
    ).format(output_full, spiece_out, vocab_size)
    SentencePieceTrainer.Train(spiece_args)
    # Load the generated vocabulary
    with open("{}.vocab".format(spiece_out), "r", encoding="utf8") as fd:
        reader = csv.reader(
            fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=""
        )
        vocab = [line[0] for line in reader]
    # Remove the special tokens <unk>, <s>, </s>
    vocab = vocab[3:]

    # Convert to BERT style: word-initial pieces lose the "\u2581" marker,
    # continuation pieces get the "##" prefix; a bare "\u2581" is dropped.
    bert_vocab = [
        v[1:] if v.startswith("▁") else "##{}".format(v) for v in vocab if v != "▁"
    ]
    # Add BERT's special tokens to the beginning
    bert_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + bert_vocab
    # Fill up with unused tokens
    pad_size = vocab_size - len(bert_vocab)
    bert_vocab += ["unused{}".format(i) for i in range(pad_size)]
    with open(os.path.join(out_dir, "vocab.txt"), "w", encoding="utf8") as fd:
        writer = csv.writer(
            fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=""
        )
        writer.writerows([[b] for b in bert_vocab])

    # Convert to GPT-2 style
    # Unfortunately it's slow and tedious.
    spinner.start(text="Generating BPE vocabulary")
    try:
        gpt2_vocab = ["Ġ{}".format(v[1:]) if v.startswith("▁") else v for v in vocab]
        # Add the GPT-2 special token to the end
        gpt2_vocab.append("<|endoftext|>")
        with open(os.path.join(out_dir, "vocab.json"), "w", encoding="utf8") as fd:
            json.dump({v: i for i, v in enumerate(gpt2_vocab)}, fd, ensure_ascii=False)
        spiece_processor = SentencePieceProcessor()
        spiece_processor.Load("{}.model".format(spiece_out))
        # Encode the whole text
        encoded = [
            [" ".join(spiece_processor.EncodeAsPieces(line[0])).replace("▁", "Ġ")]
            for line in lines
        ]
        tmp_encoded_fd, tmp_encoded_path = tempfile.mkstemp()
        tmp_bpe_fd, tmp_bpe_path = tempfile.mkstemp()
        # Only the path of the BPE output file is used; close the raw
        # descriptor right away so it is not leaked (and so the file can be
        # removed on platforms that refuse to delete open files).
        os.close(tmp_bpe_fd)
        try:
            # Write the encoded text to a temporary file.
            with os.fdopen(tmp_encoded_fd, "w", encoding="utf8") as fd:
                writer = csv.writer(
                    fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=""
                )
                writer.writerows(encoded)
            # Close both handles deterministically so the merges are flushed
            # to disk before they are read back below.
            with open(tmp_encoded_path, "r", encoding="utf8") as bpe_in, \
                    open(tmp_bpe_path, "w", encoding="utf8") as bpe_out:
                learn_bpe(bpe_in, bpe_out, num_symbols=vocab_size)
            with open(tmp_bpe_path, "r", encoding="utf8") as fd:
                reader = csv.reader(
                    fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=""
                )
                seen = set()
                merges = []
                for line in reader:
                    # Get rid of the </w> tokens
                    line = line[0].replace("</w>", "")
                    # Remove duplicates (due to </w> tokens)
                    if line not in seen:
                        seen.add(line)
                        merges.append([line])
            with open(os.path.join(out_dir, "merges.txt"), "w", encoding="utf8") as fd:
                writer = csv.writer(
                    fd, delimiter="\t", quoting=csv.QUOTE_NONE, quotechar=""
                )
                writer.writerows(merges)
        finally:
            os.remove(tmp_encoded_path)
            os.remove(tmp_bpe_path)
    finally:
        # Stop the spinner even if BPE generation fails.
        spinner.stop()
Exemple #11
0
def convert_subword(transcript: str, sp: "spm.SentencePieceProcessor"):
    """Encode ``transcript`` into SentencePiece pieces and their ids.

    :param transcript: text to encode.
    :param sp: loaded SentencePiece processor.
    :return: ``(text, label)`` where ``text`` is the pieces joined by single
        spaces and ``label`` the matching piece ids, also space-separated.
    """
    pieces = sp.EncodeAsPieces(transcript)
    text = " ".join(pieces)
    # Look up each piece, not each character of the joined string.
    label = " ".join(str(sp.PieceToId(piece)) for piece in pieces)
    return text, label
class BPE_Dictionary(object):
    """Dictionary backed by a SentencePiece model stored as ``<dict>.model`` / ``<dict>.vocab``."""

    def __init__(
        self,
        dict,
        dict_type,
        pad=constants.PAD,
        eos=constants.EOS,
        unk=constants.UNK,
        bos=constants.BOS,
    ):
        """
        :param dict: path prefix of the model files (``~`` is expanded).
        :param dict_type: dictionary type; only ``SENTENCEPIECE`` loads a model.
        :param pad: pad symbol string, stored as ``pad_word``.
        :param eos: end-of-sentence symbol string, stored as ``eos_word``.
        :param unk: unknown symbol string, stored as ``unk_word``.
        :param bos: beginning-of-sentence symbol string, stored as ``bos_word``.
        """
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.dict = os.path.expanduser(dict)
        self.dict_type = dict_type

        if self.dict_type == SENTENCEPIECE:
            assert self.exists(self.dict, self.dict_type), \
                f"missing {self.dict}.model or {self.dict}.vocab"
            self.bpe_dict = SentencePieceProcessor()
            self.bpe_dict.load(f'{self.dict}.model')
            # Special-symbol ids as configured in the SentencePiece model.
            self.pad_index = self.bpe_dict.pad_id()
            self.bos_index = self.bpe_dict.bos_id()
            self.eos_index = self.bpe_dict.eos_id()
            self.unk_index = self.bpe_dict.unk_id()

    @staticmethod
    def exists(dict, dict_type='sentencepiece'):
        """Return True if both ``<dict>.model`` and ``<dict>.vocab`` exist.

        :raises NotImplementedError: for any ``dict_type`` other than ``SENTENCEPIECE``.
        """
        dict = os.path.expanduser(dict)
        if dict_type != SENTENCEPIECE:
            raise NotImplementedError
        return (os.path.exists(f'{dict}.model')
                and os.path.exists(f'{dict}.vocab'))

    def save(self, dict_name):
        """Copy the ``.model`` / ``.vocab`` files to the prefix ``dict_name``.

        :raises NotImplementedError: for non-SentencePiece dictionaries.
        """
        dict_name = os.path.expanduser(dict_name)
        dict_dir = os.path.dirname(dict_name)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when dict_name actually has one (otherwise save into the cwd).
        if dict_dir:
            os.makedirs(dict_dir, exist_ok=True)
        if self.dict_type == SENTENCEPIECE:
            shutil.copy(f'{self.dict}.model', f'{dict_name}.model')
            shutil.copy(f'{self.dict}.vocab', f'{dict_name}.vocab')
        else:
            raise NotImplementedError

    def encode_tokens(self, sentence):
        """Return the SentencePiece pieces of ``sentence``."""
        return self.bpe_dict.EncodeAsPieces(sentence)

    def encode_ids(self, sentence):
        """Return the SentencePiece ids of ``sentence``."""
        return self.bpe_dict.EncodeAsIds(sentence)

    def string(self,
               tensor: torch.Tensor,
               bpe_symbol=None,
               escape_unk=None,
               trunc_eos=None):
        """Decode a 1-D id tensor to text, or a 2-D tensor to newline-joined rows.

        ``bpe_symbol``, ``escape_unk`` and ``trunc_eos`` are accepted for
        interface compatibility but do not affect the result.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, trunc_eos)
                for t in tensor)
        return self.bpe_dict.Decode(tensor.tolist())

    def __getitem__(self, idx):
        """Return the piece string for id ``idx``."""
        return self.bpe_dict.IdToPiece(idx)

    def __contains__(self, sym):
        """True if ``sym`` maps to an id other than the unknown id."""
        return self.index(sym) != self.unk()

    def index(self, sym):
        """Return the id of piece ``sym``."""
        return self.bpe_dict[sym]

    def __len__(self):
        return len(self.bpe_dict)

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index