from typing import Any, Dict

import pytest
import torch

# NOTE: assumed context -- in the full test module these are expected to be
# defined at the top of the file. SEQ_SIZE matches the width of the expected
# tensors below, and _SequencePreprocessing is the sequence feature's
# torchscript preprocessing module.
from ludwig.features.sequence_feature import _SequencePreprocessing

SEQ_SIZE = 6


def test_sequence_preproc_module_bert_tokenizer():
    metadata = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "bert",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "hello": 4,
            "world": 5,
            "pale": 7,
            "##ont": 8,
            "##ology": 9,
        },
    }
    module = _SequencePreprocessing(metadata)

    res = module(["paleontology", "unknown", "hello world hello", "hello world hello world"])

    assert torch.allclose(
        res,
        torch.tensor([[1, 7, 8, 9, 0, 2], [1, 3, 0, 2, 2, 2], [1, 4, 5, 4, 0, 2], [1, 4, 5, 4, 5, 0]]),
    )

def test_text_preproc_module_space_punct_tokenizer():
    metadata = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "space_punct",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "this": 4,
            "sentence": 5,
            "has": 6,
            "punctuation": 7,
            ",": 8,
            ".": 9,
        },
    }
    module = _SequencePreprocessing(metadata)

    res = module(["punctuation", ",,,,", "this... this... punctuation", "unknown"])

    assert torch.allclose(
        res,
        torch.tensor([[1, 7, 0, 2, 2, 2], [1, 8, 8, 8, 8, 0], [1, 4, 9, 9, 9, 4], [1, 3, 0, 2, 2, 2]]),
    )

def test_sequence_preproc_module_bad_tokenizer():
    metadata = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "dutch_lemmatize",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "▁hell": 4,
            "o": 5,
            "▁world": 6,
        },
    }

    # dutch_lemmatize has no torchscript-compatible implementation, so
    # constructing the preprocessing module should raise.
    with pytest.raises(ValueError):
        _SequencePreprocessing(metadata)

def create_preproc_module(metadata: Dict[str, Any]) -> torch.nn.Module:
    return _SequencePreprocessing(metadata)
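

# A minimal usage sketch (an illustration, not part of the original suite):
# create_preproc_module returns a plain nn.Module, so it should survive
# torch.jit.script for inference-time preprocessing. The metadata below is a
# hypothetical vocabulary reusing the conventions of the tests above, and the
# scripting call assumes the chosen tokenizer ("space" here) is
# TorchScript-compatible.
def example_scripted_preproc_module():
    metadata = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "space",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {"<EOS>": 0, "<SOS>": 1, "<PAD>": 2, "<UNK>": 3, "hello": 4, "world": 5},
    }
    module = torch.jit.script(create_preproc_module(metadata))

    # Following the pattern above, each row should be <SOS>-prefixed,
    # <EOS>-terminated, and <PAD>-filled out to SEQ_SIZE columns.
    res = module(["hello world"])
    assert res.shape == (1, SEQ_SIZE)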