def main():
    # Instantiate argument parser
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file",
        default=None,
        type=str,
        required=True,
        help=
        "The input training data file or a path to a directory with multiple training data files."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The output directory where the tokenizer model will be written.")
    # Optional parameters
    parser.add_argument("--vocab_size",
                        default=5000,
                        type=int,
                        help="Vocabulary maximum size, default 5000.")
    parser.add_argument("--min_freq",
                        default=2,
                        type=int,
                        help="Minimum number of occurrences, default 2")

    # Generate args
    args = parser.parse_args()

    # Initialize a tokenizer
    tokenizer = ByteLevelBPETokenizer()

    # Get training files: a single .txt file is used directly; otherwise all
    # .txt files under the given directory are collected recursively
    paths = os.path.abspath(args.train_data_file)
    if not args.train_data_file.endswith(".txt"):
        paths = [str(x) for x in Path(paths).glob("**/*.txt")]

    # Customize training
    tokenizer.train(files=paths,
                    vocab_size=args.vocab_size,
                    min_frequency=args.min_freq,
                    special_tokens=[
                        "<s>",
                        "<pad>",
                        "</s>",
                        "<unk>",
                        "<mask>",
                    ])

    tokenizer.add_special_tokens(["<x>", "<z>"])

    # Save files to disk
    output_dir = os.path.abspath(args.output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    tokenizer.save_model(output_dir)
def load_tokenizer(vocab='./tokenizer/vocab.json', merges='./tokenizer/merges.txt', gpt=False, load_from=None):
    if gpt:
        if load_from:
            tokenizer = GPT2Tokenizer.from_pretrained(load_from)
        else:
            tokenizer = GPT2Tokenizer(
                vocab, merges, 
                bos_token=CARD_BEGIN, eos_token=CARD_END, sep_token=CARD_END,
                unk_token=UNK, pad_token=CARD_PAD, mask_token=CARD_MASK, padding_side="left"
            )
    else:
        tokenizer = ByteLevelBPETokenizer(vocab, merges)
        tokenizer.add_special_tokens(SPECIAL_TOKENS + OTHER_TOKENS)
        tokenizer.mask_token = CARD_MASK
    
    tokenizer.pre_tokenizer = Whitespace()
    return tokenizer
class ByteBPETokenizer:
    def __init__(self, vocab_json, merge_txt, max_length=750):
        self.tokenizer = ByteLevelBPETokenizer(vocab_json, merge_txt)
        self.tokenizer.enable_truncation(max_length=max_length)
        self.tokenizer.enable_padding(max_length=max_length)
        self.tokenizer.add_special_tokens(["[PAD]", "[CLS]"])
        # self.tokenizer.post_processor = RobertaProcessing(("</s>", 2), ("<s>", 1))
        # self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def encode(self, review):
        review = clean_sentence(review)
        encoded = self.tokenizer.encode(review.lower())
        # pp_encoded = self.tokenizer.post_process(encoded)
        return encoded

    def tokenize2Index(self, review, should_stem=False):
        encoded = self.encode(review)

        return encoded.ids

    def trainBPE(self, paths, vocab_size=30000, min_frequency=10, special_tokens=["[PAD]", "[CLS]"]):
        tokenizer = ByteLevelBPETokenizer()
        tokenizer.train(files=paths, vocab_size=vocab_size, min_frequency=min_frequency, special_tokens=special_tokens)
        tokenizer.save("yelp_bpe/", "yelp-bpe")
Example #4
def create_tokenizer(corpus_file_path, vocab_size):
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train(corpus_file_path, vocab_size)
    tokenizer.add_special_tokens(['<SOS>', '<PAD>', '<EOS>'])
    return tokenizer
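A hedged usage sketch for create_tokenizer; the corpus path and vocabulary size are hypothetical, and any plain-text file works:

# Hypothetical corpus file and vocabulary size
tokenizer = create_tokenizer("corpus.txt", vocab_size=10000)

# encode() returns an Encoding object; .ids holds the integer ids
ids = tokenizer.encode("<SOS> an example sentence <EOS>").ids
print(tokenizer.decode(ids))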
Example #5
class HuggingFaceBpeHelper(BPEHelper):
    """
    HuggingFace's ByteLevelBPE Tokenizer.

    Fast because Rust.
    """

    def __init__(self, opt: Opt, shared: TShared = None):
        super().__init__(opt, shared)
        # Default true for HF
        self.special_tok_map = {}  # map from ParlAI special token ids to HF token ids
        self.add_prefix_space = opt.get('bpe_add_prefix_space', True)
        if self.add_prefix_space is None:
            self.add_prefix_space = True
        if opt.get('dict_loaded'):
            dfname = opt['dict_file']
            if PathManager.exists(f'{dfname}-merges.txt'):
                opt['bpe_merge'] = f'{dfname}-merges.txt'
            if PathManager.exists(f'{dfname}-vocab.json'):
                opt['bpe_vocab'] = f'{dfname}-vocab.json'
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                'Please install HuggingFace tokenizer with: pip install tokenizers'
            )

        if self.bpe_dropout:
            raise NotImplementedError(
                '--bpe-dropout is not supported with ByteLevelBPE because tokenizers '
                'library does not allow dynamically turning BPE on/off. You can use '
                '--dict-tokenizer slow_bytelevel_bpe to gain this feature.'
            )

        if self.lower:
            warn_once('Are you sure you want to lower case your BPE dictionary?')
        if self.maxtokens > 0 or self.minfreq > 0:
            raise ValueError(
                'You should not filter vocabulary when using --dict-tokenizer bytelevelbpe'
                ' (no --dict-minfreq or --dict-maxtokens).'
            )
        if 'bpe_vocab' not in opt:
            raise ValueError('--bpe-vocab is required for loading pretrained tokenizer')
        if 'bpe_merge' not in opt:
            raise ValueError('--bpe-merge is required for loading pretrained tokenizer')

        self.vocab_path = opt['bpe_vocab']
        self.merge_path = opt['bpe_merge']

        if not self.vocab_path or not self.merge_path:
            raise IOError(
                '--bpe-vocab and --bpe-merge are mandatory with '
                '--dict-tokenizer bytelevelbpe'
            )

        if not PathManager.exists(self.vocab_path):
            raise IOError(
                f'File {self.vocab_path} does not exist. --bpe-vocab must be pretrained.'
            )
        if not PathManager.exists(self.merge_path):
            raise IOError(
                f'File {self.merge_path} does not exist. --bpe-merge must be pretrained.'
            )

        self.tokenizer = ByteLevelBPETokenizer(
            self.vocab_path, self.merge_path, self.add_prefix_space
        )

    def helper_encode(self, text: str) -> List[str]:
        """
        Encode a text string into a list of tokens.

        :param text:
            text string to encode

        :return tokens:
            encoded list of tokens
        """
        return self.tokenizer.encode(text).tokens

    def helper_decode(
        self, tokens: List[str], token_ids: List[int], delimiter: str
    ) -> str:
        """
        Decode list of tokens into text string.

        :param tokens:
            list of tokens
        :param token_ids:
            list of token ids
        :param delimiter:
            string delimiter for tokens

        :return text:
            decoded text
        """
        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)

        return text

    def add_special_tokens(self, dict_agent, special_tokens: List[str]):
        """
        Add special tokens to the tokenizer and dict_agent.
        """
        logging.debug(f'adding the following special tokens: {special_tokens}')
        self.tokenizer.add_special_tokens(special_tokens)  # add to HF

        for tok in special_tokens:
            parlai_key = dict_agent[tok]
            hf_key = self.tokenizer.token_to_id(tok)
            self.special_tok_map[parlai_key] = hf_key

    def sync_with_dict(self, dict_agent):
        """
        Sync the dictionary agent with Hugging Face tokenizer's BPE dict.

        Called only once on initialization.
        """
        special_tokens = [
            dict_agent.null_token,
            dict_agent.start_token,
            dict_agent.end_token,
            dict_agent.unk_token,
        ]
        self.add_special_tokens(dict_agent, special_tokens)

        for i in range(self.tokenizer.get_vocab_size() - len(special_tokens)):
            token = self.tokenizer.id_to_token(i)
            dict_agent.add_token(token)
            # We don't have access to the hugging face word frequency table,
            # just set it to 1 instead
            dict_agent.freq[token] = 1

    def save(self, dir_name: str, file_name: str):
        """
        Save appropriate files.

        :param dir_name:
            directory to save.
        :param file_name:
            file to save.
        """
        self.tokenizer.save_model(dir_name, file_name)
Example #6
if TRAIN_BASE:
    # Initialize a tokenizer
    tokenizer = ByteLevelBPETokenizer()

    # Customize training
    tokenizer.train(files=paths, vocab_size=52_000, min_frequency=2, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ])

    # Save files to disk
    tokenizer.save_model("tokenizer")

inp = 'print("Hello World!")'

tokenizer = GPT2Tokenizer.from_pretrained('tokenizer')

tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "pad_token": "<pad>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "mask_token": "<mask>",
})

t = tokenizer.encode(inp)

print(t)
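To round-trip the printed ids back to text with the same GPT2Tokenizer (a small sketch; the exact output depends on the trained vocabulary):

decoded = tokenizer.decode(t)
print(decoded)  # should reproduce the original input string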
Example #7
File: bpe.py Project: nii4u/ParlAI
class HuggingFaceBpeHelper(BPEHelper):
    """
    HuggingFace's ByteLevelBPE Tokenizer.

    Fast because Rust.
    """
    def __init__(self, opt: Opt, shared: TShared = None):
        super().__init__(opt, shared)
        # Default true for HF
        self.add_prefix_space = opt.get('bpe_add_prefix_space', True)
        if self.add_prefix_space is None:
            self.add_prefix_space = True
        if opt.get('dict_loaded'):
            dfname = opt['dict_file']
            if os.path.isfile(f'{dfname}-merges.txt'):
                opt['bpe_merge'] = f'{dfname}-merges.txt'
            if os.path.isfile(f'{dfname}-vocab.json'):
                opt['bpe_vocab'] = f'{dfname}-vocab.json'
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                'Please install HuggingFace tokenizer with: pip install tokenizers'
            )

        if self.lower:
            raise ValueError(
                'Only use --dict-lower false with --dict-tokenizer bytelevelbpe'
            )
        if self.maxtokens > 0 or self.minfreq > 0:
            raise ValueError(
                'You should not filter vocabulary when using --dict-tokenizer bytelevelbpe'
                ' (no --dict-minfreq or --dict-maxtokens).')
        if 'bpe_vocab' not in opt:
            raise ValueError(
                '--bpe-vocab is required for loading pretrained tokenizer')
        if 'bpe_merge' not in opt:
            raise ValueError(
                '--bpe-merge is required for loading pretrained tokenizer')

        self.vocab_path = opt['bpe_vocab']
        self.merge_path = opt['bpe_merge']

        if not self.vocab_path or not self.merge_path:
            raise IOError('--bpe-vocab and --bpe-merge are mandatory with '
                          '--dict-tokenizer bytelevelbpe')

        if not os.path.isfile(self.vocab_path):
            raise IOError(
                f'File {self.vocab_path} does not exist. --bpe-vocab must be pretrained.'
            )
        if not os.path.isfile(self.merge_path):
            raise IOError(
                f'File {self.merge_path} does not exist. --bpe-merge must be pretrained.'
            )

        self.tokenizer = ByteLevelBPETokenizer(self.vocab_path,
                                               self.merge_path,
                                               self.add_prefix_space)

    def helper_encode(self, text: str) -> List[str]:
        """
        Encode a text string into a list of tokens.

        :param text:
            text string to encode

        :return tokens:
            encoded list of tokens
        """
        return self.tokenizer.encode(text).tokens

    def helper_decode(self, tokens: List[str], token_ids: List[int],
                      delimiter: str) -> str:
        """
        Decode list of tokens into text string.

        :param tokens:
            list of tokens
        :param token_ids:
            list of token ids
        :param delimiter:
            string delimiter for tokens

        :return text:
            decoded text
        """
        text = self.tokenizer.decode(token_ids)
        return text

    def sync_with_dict(self, dict_agent):
        """
        Sync the dictionary agent with Hugging Face tokenizer's BPE dict.

        Called only once on initialization.
        """
        special_tokens = [
            dict_agent.null_token,
            dict_agent.start_token,
            dict_agent.end_token,
            dict_agent.unk_token,
        ]
        self.tokenizer.add_special_tokens(special_tokens)
        for i in range(self.tokenizer.get_vocab_size() - 4):
            token = self.tokenizer.id_to_token(i)
            dict_agent.add_token(token)
            # We don't have access to the hugging face word frequency table,
            # just set it to 1 instead
            dict_agent.freq[token] = 1

    def save(self, dir_name: str, file_name: str):
        """
        Save appropriate files.

        :param dir_name:
            directory to save.
        :param file_name:
            file to save.
        """
        self.tokenizer.save(dir_name, file_name)
Example #8
class CodeTrainedBPE_Translation_DataProcessor(DataProcessor, Dataset):
    def __init__(self, task_data, max_src_len=512, max_tgt_len=512):
        """
        This data processor tokenizes and numericalises using a custom byte pair 
        encoding trained on the codeSearchNet train data with full docstrings.
        """
        self.task_data = task_data
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len
        self.tokenizer = ByteLevelBPETokenizer(
            "/nfs/phd_by_carlos/notebooks/datasets/code_search_net/code_bpe_hugging_32k-vocab.json",
            "/nfs/phd_by_carlos/notebooks/datasets/code_search_net/code_bpe_hugging_32k-merges.txt"
        )
        self.tokenizer.add_special_tokens(["[CLS]", "[SOS]", "[EOS]", "[PAD]"])
        self.SOS = self.tokenizer.encode("[SOS]").ids[0]
        self.EOS = self.tokenizer.encode("[EOS]").ids[0]
        self.PAD = self.tokenizer.encode("[PAD]").ids[0]
        self.CLS = self.tokenizer.encode("[CLS]").ids[0]

        self.__remove_long_samples()

    def __len__(self):
        return len(self.task_data)

    def __getitem__(self, idx):
        src, tgt = self.task_data[idx]
        sample = {'src': self.encode(src), 'tgt': self.encode(tgt)}
        return sample

    @property
    def vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def __remove_long_samples(self):
        for i in tqdm.tqdm(list(reversed(range(len(self.task_data)))),
                           desc="removing long samples"):
            src, tgt = self.task_data[i]
            if len(self.encode(src)) > self.max_src_len or len(
                    self.encode(tgt)) > self.max_tgt_len:
                del self.task_data[i]

    def encode(self, sample):
        """
        sample: str: the input string to encode
        """
        return [self.SOS] + self.tokenizer.encode(sample).ids + [self.EOS]

    def encode_src(self, sample):
        return self.encode(sample)

    def encode_tgt(self, sample):
        return self.encode(sample)

    def encode_to_tensor(self, input_samples):
        """
        input_samples: [str]: one or more strings to convert to a single padded tensor. (Seq_len x batch)
        """
        return pad_sequence([
            torch.Tensor(self.encode(sample)).type(torch.LongTensor)
            for sample in input_samples
        ],
                            padding_value=self.PAD)

    def collate(self, input_samples):
        """
        input_samples: [dict]: these are samples obtained through the _get_item method
        """
        collated_samples = {}
        sample_keys = input_samples[0].keys()
        for key in sample_keys:
            collated_samples[key] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.Tensor(sample[key]).type(torch.LongTensor)
                    for sample in input_samples
                ],
                padding_value=self.PAD)
        return collated_samples

    def decode(self, ids):
        """
        ids: [int]: ids to decode
        """
        return self.tokenizer.decode(ids)

    def decode_src(self, ids):
        return self.decode(ids)

    def decode_tgt(self, ids):
        return self.decode(ids)

    def validate_prediction(self, numerical_sequence):
        # there are no constraints
        return True

    def prediction_is_complete(self, numerical_sequence):
        return self.EOS in numerical_sequence

    def decode_tensor(self, output_tensor):
        """
        output_tensor: [[int]]: model output (Seq_len x batch)
        """
        batch_first_output_tensor = output_tensor.T
        return [
            self.decode(sequence.cpu().tolist())
            for sequence in batch_first_output_tensor
        ]

    def to_dataloader(self,
                      batch_size,
                      repeat=False,
                      num_workers=4,
                      shuffle=True):
        """
        This function returns an iterable object with all the data batched.
        
        >>> BPE_processor = CodeTrainedBPE_Translation_DataProcessor(validation_pairs, max_tgt_len=100)
        >>> dataloader = BPE_processor.to_dataloader(2)
        
        >>> for i_batch, sample_batched in enumerate(dataloader):
        >>>     print(sample_batched["tgt"])
        >>>     print(BPE_processor.decode_tensor(sample_batched["tgt"]))
        >>>     break
        """
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers,
                          drop_last=False, collate_fn=self.collate, shuffle=shuffle)

    def save(self, path):
        torch.save(self, path)
Example #9
class Parse_Tree_Translation_DataProcessor(Dataset):
    def __init__(
            self,
            task_data,
            max_length=500,
            tokenizer_dir="/nfs/phd_by_carlos/notebooks/datasets/code_search_net/",
            grammar_path="src/tree-sitter/tree-sitter-python/src/grammar.json",
            **kwargs):
        self.task_data = task_data
        self.max_length = max_length
        self.tokenizer = ByteLevelBPETokenizer(
            tokenizer_dir + "code_bpe_hugging_32k-vocab.json",
            tokenizer_dir + "code_bpe_hugging_32k-merges.txt")
        self.tokenizer.add_special_tokens(["[CLS]", "[SOS]", "[EOS]", "[PAD]"])
        self.SOS = self.tokenizer.encode("[SOS]").ids[0]
        self.EOS = self.tokenizer.encode("[EOS]").ids[0]
        self.PAD = self.tokenizer.encode("[PAD]").ids[0]
        self.CLS = self.tokenizer.encode("[CLS]").ids[0]

        with open(grammar_path, "r") as grammar_file:
            self.python_grammar = json.load(grammar_file)

        extra_externals = {
            "_string_start": {
                "type": "PATTERN",
                "value": '"'
            },
            "_string_content": {
                "type": "PATTERN",
                "value": "[A-Za-z0-9 _,.()\/{}!$@'*]*"
            },
            "_string_end": {
                "type": "PATTERN",
                "value": '"'
            },
            "_newline": {
                "type": "BLANK"
            }
        }
        for node_type, member in extra_externals.items():
            self.python_grammar["rules"][node_type] = member

        self.python_parser = Code_Parser(self.python_grammar, "python",
                                         **kwargs)
        self.node_processor = Node_Processor()
        self.tree_vocab, grammar_patterns = get_grammar_vocab(
            self.python_grammar)

        self.tokenizer.add_tokens(["<REDUCE>"])
        for tree_token in sorted(self.tree_vocab):
            if len(self.tokenizer.encode(tree_token).tokens) != 1:
                self.tokenizer.add_tokens([tree_token])

        # filtering the data
        filtered_task_data = []
        for desc, code in self.task_data:
            numerical_code_sequence = self.encode_tgt(code)
            numerical_desc_sequence = self.encode_src(desc)
            token_sequence = self.numerical_to_token_sequence(
                numerical_code_sequence)
            if self.python_parser.is_valid_sequence(token_sequence) and len(
                    token_sequence) <= max_length and len(
                        numerical_desc_sequence) <= max_length:
                filtered_task_data.append((desc, code))
            elif len(token_sequence) > max_length or len(
                    numerical_desc_sequence) > max_length:
                print(
                    f"Sequence too long: src->{len(numerical_desc_sequence)}, tgt->{len(token_sequence)}"
                )
            else:
                print(f"Could not parse and reconstruct: {code}")
        self.task_data = filtered_task_data

    def __len__(self):
        return len(self.task_data)

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError

        src, tgt = self.task_data[idx]
        sample = {'src': self.encode_src(src), 'tgt': self.encode_tgt(tgt)}
        return sample

    @property
    def vocab_size(self):
        return self.tokenizer.get_vocab_size()

    def encode_src(self, desc_str):
        return [self.SOS] + self.tokenizer.encode(desc_str).ids + [self.EOS]

    def encode_tgt(self, code_str):
        code_sequence = self.python_parser.code_to_sequence(code_str)
        numerical_code = []
        for code_token in code_sequence:
            numerical_code += self.tokenizer.encode(code_token).ids
        return [self.SOS] + numerical_code + [self.EOS]

    def decode_src(self, numerical_desc):
        """
        numerical_desc: [int]: ids to decode
        """
        return self.tokenizer.decode(numerical_desc)

    def numerical_to_token_sequence(self, numerical_code):
        token_sequence = [
            self.tokenizer.decode([token_idx]) for token_idx in numerical_code
            if token_idx not in [self.SOS, self.EOS, self.PAD, self.CLS]
        ]
        return token_sequence

    def decode_tgt(self, numerical_code):
        token_sequence = self.numerical_to_token_sequence(numerical_code)
        partial_tree = self.python_parser.sequence_to_partial_tree(
            token_sequence)
        return self.node_processor.pretty_print(
            partial_tree.root), partial_tree

    def validate_prediction(self, current_prediction):
        #         print(f"validating: {current_prediction}")
        token_sequence = self.numerical_to_token_sequence(current_prediction)
        return self.python_parser.is_valid_sequence(token_sequence)

    def prediction_is_complete(self, current_prediction):
        token_sequence = self.numerical_to_token_sequence(current_prediction)
        return self.python_parser.sequence_to_partial_tree(
            token_sequence).is_complete

    def collate(self, input_samples):
        """
        input_samples: [dict]: these are samples obtained through the _get_item method
        """
        collated_samples = {}
        sample_keys = input_samples[0].keys()
        for key in sample_keys:
            collated_samples[key] = torch.nn.utils.rnn.pad_sequence(
                [
                    torch.Tensor(sample[key]).type(torch.LongTensor)
                    for sample in input_samples
                ],
                padding_value=self.PAD)
        return collated_samples

    def to_dataloader(self, batch_size, num_workers=4, shuffle=True):
        """
        This function returns an iterable object with all the data batched.
        
        >>> BPE_processor = CodeTrainedBPE_Translation_DataProcessor(validation_pairs, max_tgt_len=100)
        >>> dataloader = BPE_processor.to_dataloader(2)
        
        >>> for i_batch, sample_batched in enumerate(dataloader):
        >>>     print(sample_batched["tgt"])
        >>>     print(BPE_processor.decode_tensor(sample_batched["tgt"]))
        >>>     break
        """
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers,
                          drop_last=False, collate_fn=self.collate, shuffle=shuffle)

    def save(self, path):
        torch.save(self, path)
Example #10
class HuggingfaceTokenizerBPE(nn.Module):
    def __init__(self, text_files, dataset_info_path='', config_data=None):
        super().__init__()
        # The default vocab size in the BERT model is 30522. If we want a number larger than that, we will also have to
        # change the BERT configuration.
        vocab_size = 30000
        self.info = f'hug{vocab_size}'

        with open(f'config/data/{config_data}.json') as json_file:
            tokenizer_from = json.load(json_file)['tokenizer_from']

        config_name = config_data if tokenizer_from == "" else tokenizer_from
        print(
            os.path.join(dataset_info_path,
                         f'tokenizer_{config_name}_{vocab_size}-vocab.json'))

        # The loading is only properly implemented starting from version 0.8. However, it makes the system use a lot of
        #  CPU for no reason (it is much slower). Maybe it will be fixed in the future.
        if not os.path.isfile(
                os.path.join(
                    dataset_info_path,
                    f'tokenizer_{config_name}_{vocab_size}-vocab.json')):
            text_files = text_files()
            self.tokenizer = ByteLevelBPETokenizer()
            # Join into a single file. This should NOT be necessary but it does not work properly with a lot of files
            with open('/tmp/text_files.txt', 'wb') as outfile:
                for filename in tqdm(
                        text_files,
                        desc='Joining all files into one for tokenization'):
                    with open(filename, 'rb') as readfile:
                        shutil.copyfileobj(readfile, outfile)
                text_files = '/tmp/text_files.txt'
            self.tokenizer.train(text_files,
                                 vocab_size=vocab_size,
                                 special_tokens=special_tokens)
            self.tokenizer.save(dataset_info_path,
                                f'tokenizer_{config_name}_{vocab_size}')

        # No "else", always load for consistency
        vocab_file = os.path.join(
            dataset_info_path,
            f'tokenizer_{config_name}_{vocab_size}-vocab.json')
        merges_file = os.path.join(
            dataset_info_path,
            f'tokenizer_{config_name}_{vocab_size}-merges.txt')
        self.tokenizer = ByteLevelBPETokenizer(vocab_file=vocab_file,
                                               merges_file=merges_file)
        self.tokenizer.add_special_tokens(special_tokens)

        self.index_special_tokens = {
            tok: self.tokenizer.encode(tok).ids[0]
            for tok in special_tokens
        }

    @property
    def device(self):
        return self._float_tensor.device

    def encode(self, sentence: str):
        output = self.tokenizer.encode(sentence)
        token_ids = output.ids
        tokens = output.tokens
        return torch.tensor(token_ids), tokens

    def decode(self, tokens: torch.LongTensor):
        assert tokens.dim() == 1
        tokens = list(tokens.cpu().numpy())
        sentences = self.tokenizer.decode(tokens)
        return sentences

    def id_to_token(self, token_id):
        if type(token_id) != torch.Tensor:
            token_id = torch.tensor(token_id)
        return self.tokenizer.id_to_token(token_id)

    def token_to_id(self, token):
        assert type(token) == str
        return self.tokenizer.token_to_id(token)

    def __len__(self):
        return self.tokenizer.get_vocab_size()

    # This is simply for PyCharm to find the correct reference to the methods of the class
    def __call__(self, *input, **kwargs) -> typing.Any:
        return super().__call__(*input, **kwargs)