class NemoBertTokenizer(TokenizerSpec):
    """WordPiece tokenizer that adapts HuggingFace's ``BertTokenizer`` to the
    ``TokenizerSpec`` interface.

    Args:
        pretrained_model: name of a pretrained BERT checkpoint whose vocabulary
            should be loaded; takes precedence over ``vocab_file``.
        vocab_file: path to a WordPiece vocabulary file, used only when
            ``pretrained_model`` is not given.
        do_lower_case: lower-case input before tokenization (vocab-file path only).
        max_len: maximum sequence length enforced by the underlying tokenizer.
        do_basic_tokenize: run basic whitespace/punctuation tokenization before
            WordPiece.
        never_split: special tokens that must never be split into sub-words.
    """

    def __init__(self, pretrained_model=None, vocab_file=None, do_lower_case=True,
                 max_len=None, do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        if pretrained_model:
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)
            # Cased checkpoints must not be lower-cased; override the
            # basic tokenizer's flag for any non-"uncased" model name.
            if "uncased" not in pretrained_model:
                self.tokenizer.basic_tokenizer.do_lower_case = False
        else:
            # BUG FIX: the original passed max_len, do_basic_tokenize and
            # never_split positionally, but BertTokenizer's positional order is
            # (vocab_file, do_lower_case, do_basic_tokenize, never_split, ...),
            # so max_len was bound to do_basic_tokenize and the rest shifted.
            # Keyword arguments bind each value to the intended parameter.
            self.tokenizer = BertTokenizer(
                vocab_file,
                do_lower_case=do_lower_case,
                max_len=max_len,
                do_basic_tokenize=do_basic_tokenize,
                never_split=never_split,
            )
        self.vocab_size = len(self.tokenizer.vocab)
        self.never_split = never_split

    def text_to_tokens(self, text):
        """Split *text* into WordPiece tokens."""
        return self.tokenizer.tokenize(text)

    def tokens_to_text(self, tokens):
        """Join *tokens* back into text and normalize spaces/quotes."""
        text = self.tokenizer.convert_tokens_to_string(tokens)
        return remove_spaces(handle_quotes(text.strip()))

    def token_to_id(self, token):
        """Return the vocabulary id of a single *token*."""
        return self.tokens_to_ids([token])[0]

    def tokens_to_ids(self, tokens):
        """Map a list of tokens to their vocabulary ids."""
        return self.tokenizer.convert_tokens_to_ids(tokens)

    def ids_to_tokens(self, ids):
        """Map a list of vocabulary ids back to tokens."""
        return self.tokenizer.convert_ids_to_tokens(ids)

    def text_to_ids(self, text):
        """Tokenize *text* and return the corresponding vocabulary ids."""
        return self.tokens_to_ids(self.text_to_tokens(text))

    def ids_to_text(self, ids):
        """Convert ids to text, dropping special tokens (``never_split``)."""
        tokens = self.ids_to_tokens(ids)
        tokens_clean = [t for t in tokens if t not in self.never_split]
        return self.tokens_to_text(tokens_clean)

    def pad_id(self):
        """Id of the padding token."""
        return self.tokens_to_ids(["[PAD]"])[0]

    def bos_id(self):
        """Id of the beginning-of-sequence ([CLS]) token."""
        return self.tokens_to_ids(["[CLS]"])[0]

    def eos_id(self):
        """Id of the end-of-sequence ([SEP]) token."""
        return self.tokens_to_ids(["[SEP]"])[0]
class BertBPE(object):
    """BPE wrapper exposing HuggingFace's ``BertTokenizer`` through a
    fairseq-style encode/decode interface."""

    @staticmethod
    def add_args(parser):
        """Register the BPE command-line options on *parser*."""
        # fmt: off
        parser.add_argument('--bpe-cased', action='store_true',
                            help='set for cased BPE',
                            default=False)
        parser.add_argument('--bpe-vocab-file', type=str,
                            help='bpe vocab file.')
        # fmt: on

    def __init__(self, args):
        """Build the underlying BertTokenizer from *args* (an argparse Namespace).

        Raises:
            ImportError: if pytorch-transformers is not installed.
        """
        try:
            from pytorch_transformers import BertTokenizer
            from pytorch_transformers.tokenization_utils import clean_up_tokenization
        except ImportError:
            # BUG FIX: the two adjacent string literals were concatenated
            # without a separating space, producing
            # "...pytorch_transformerswith: pip install...".
            raise ImportError(
                'Please install 1.0.0 version of pytorch_transformers '
                'with: pip install pytorch-transformers')

        # BUG FIX: the original tested `'bpe_vocab_file' in args`, which checks
        # attribute *presence* on the Namespace. add_args always defines the
        # attribute (default None), so the vocab-file branch was always taken
        # and BertTokenizer received a None path. Test the value instead.
        if getattr(args, 'bpe_vocab_file', None):
            self.bert_tokenizer = BertTokenizer(
                args.bpe_vocab_file,
                do_lower_case=not args.bpe_cased)
        else:
            vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
            self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)

        self.clean_up_tokenization = clean_up_tokenization

    def encode(self, x: str) -> str:
        """Tokenize *x* and return the space-separated WordPiece tokens."""
        return ' '.join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        """Invert :meth:`encode`: join WordPiece tokens back into cleaned text."""
        return self.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(' ')))

    def is_beginning_of_word(self, x: str) -> bool:
        """A token starts a new word unless it has the '##' continuation prefix."""
        return not x.startswith('##')