from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, Strip, StripAccents
from tokenizers.pre_tokenizers import Punctuation, Whitespace
from tokenizers.pre_tokenizers import Sequence as PreSequence


def get_recurrent_tokenizer(vocab, max_context_tokens, unk_token, pad_token, device="cpu"):
    """
    Return a tokenizer to be used with recurrent-based models
    """
    # Word-level tokenizer for questions: strip accents, lowercase and trim,
    # then split on whitespace and punctuation; pad on the right.
    question_tokenizer = Tokenizer(WordLevel(vocab, unk_token=unk_token))
    question_tokenizer.normalizer = Sequence(
        [StripAccents(), Lowercase(), Strip()]
    )
    question_tokenizer.pre_tokenizer = PreSequence(
        [Whitespace(), Punctuation()]
    )
    question_tokenizer.enable_padding(
        direction="right",
        pad_id=vocab[pad_token],
        pad_type_id=1,
        pad_token=pad_token,
    )

    # Same pipeline for contexts, plus truncation to the maximum context length.
    context_tokenizer = Tokenizer(WordLevel(vocab, unk_token=unk_token))
    context_tokenizer.normalizer = Sequence(
        [StripAccents(), Lowercase(), Strip()]
    )
    context_tokenizer.pre_tokenizer = PreSequence(
        [Whitespace(), Punctuation()]
    )
    context_tokenizer.enable_padding(
        direction="right",
        pad_id=vocab[pad_token],
        pad_type_id=1,
        pad_token=pad_token,
    )
    context_tokenizer.enable_truncation(max_context_tokens)

    # RecurrentSquadTokenizer is defined elsewhere in the project.
    return RecurrentSquadTokenizer(question_tokenizer, context_tokenizer, device=device)
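# --- Usage sketch (illustrative, not from the original source) ---
# The toy vocabulary, the "[PAD]"/"[UNK]" special-token names and the argument
# values below are assumptions; RecurrentSquadTokenizer is expected to be
# provided by the surrounding project.
vocab = {"[PAD]": 0, "[UNK]": 1, "what": 2, "is": 3, "squad": 4}
tokenizer = get_recurrent_tokenizer(
    vocab,
    max_context_tokens=512,
    unk_token="[UNK]",
    pad_token="[PAD]",
    device="cpu",
)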
def test_instantiate(self):
    assert Punctuation() is not None
    assert Punctuation("removed") is not None
    assert isinstance(Punctuation(), PreTokenizer)
    assert isinstance(Punctuation(), Punctuation)
    assert isinstance(pickle.loads(pickle.dumps(Punctuation())), Punctuation)
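# --- Side note (illustrative, not part of the original tests) ---
# The positional argument accepted by Punctuation above is its split behavior.
# The default, "isolated", keeps each punctuation character as its own piece,
# while "removed" drops punctuation from the output entirely.
from tokenizers.pre_tokenizers import Punctuation

print(Punctuation().pre_tokenize_str("Hey friend!"))
print(Punctuation("removed").pre_tokenize_str("Hey friend!"))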
def test_bert_like(self):
    pre_tokenizer = Sequence([WhitespaceSplit(), Punctuation()])
    assert isinstance(Sequence([]), PreTokenizer)
    assert isinstance(Sequence([]), Sequence)
    assert isinstance(pickle.loads(pickle.dumps(pre_tokenizer)), Sequence)

    # The input deliberately contains a run of spaces between "friend!" and
    # "How", which is why "How" starts at offset 16 rather than 12.
    result = pre_tokenizer.pre_tokenize_str("Hey friend!     How are you?!?")
    assert result == [
        ("Hey", (0, 3)),
        ("friend", (4, 10)),
        ("!", (10, 11)),
        ("How", (16, 19)),
        ("are", (20, 23)),
        ("you", (24, 27)),
        ("?", (27, 28)),
        ("!", (28, 29)),
        ("?", (29, 30)),
    ]