예제 #1
0
def test_sequence_preproc_module_bert_tokenizer():
    """BERT wordpiece tokenization: subword pieces map to vocab ids,
    out-of-vocabulary tokens fall back to <UNK>, and every row is wrapped
    in <SOS>/<EOS> then right-padded with <PAD> up to SEQ_SIZE."""
    meta = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "bert",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "hello": 4,
            "world": 5,
            "pale": 7,
            "##ont": 8,
            "##ology": 9,
        },
    }
    preproc = _SequencePreprocessing(meta)

    inputs = [
        "paleontology", "unknown", "hello world hello",
        "hello world hello world"
    ]
    encoded = preproc(inputs)

    # "paleontology" -> pale/##ont/##ology; "unknown" is OOV -> <UNK>.
    expected = torch.tensor([
        [1, 7, 8, 9, 0, 2],
        [1, 3, 0, 2, 2, 2],
        [1, 4, 5, 4, 0, 2],
        [1, 4, 5, 4, 5, 0],
    ])
    assert torch.allclose(encoded, expected)
예제 #2
0
def test_text_preproc_module_space_punct_tokenizer():
    """space_punct tokenization: words and punctuation marks become
    separate tokens, OOV words fall back to <UNK>, rows are wrapped in
    <SOS>/<EOS> and right-padded with <PAD>, truncating at SEQ_SIZE."""
    meta = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "space_punct",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "this": 4,
            "sentence": 5,
            "has": 6,
            "punctuation": 7,
            ",": 8,
            ".": 9,
        },
    }
    preproc = _SequencePreprocessing(meta)

    inputs = ["punctuation", ",,,,", "this... this... punctuation", "unknown"]
    encoded = preproc(inputs)

    # Each "," / "." is its own token; the long third row is truncated
    # before its <EOS> fits.
    expected = torch.tensor([
        [1, 7, 0, 2, 2, 2],
        [1, 8, 8, 8, 8, 0],
        [1, 4, 9, 9, 9, 4],
        [1, 3, 0, 2, 2, 2],
    ])
    assert torch.allclose(encoded, expected)
예제 #3
0
def test_sequence_preproc_module_bad_tokenizer():
    """An unsupported tokenizer name must be rejected with a ValueError
    at construction time, before any input is processed."""
    meta = {
        "preprocessing": {
            "lowercase": True,
            "tokenizer": "dutch_lemmatize",
            "unknown_symbol": "<UNK>",
            "padding_symbol": "<PAD>",
            "computed_fill_value": "<UNK>",
        },
        "max_sequence_length": SEQ_SIZE,
        "str2idx": {
            "<EOS>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<UNK>": 3,
            "▁hell": 4,
            "o": 5,
            "▁world": 6
        },
    }

    # Construction itself should raise — no call to the module is needed.
    with pytest.raises(ValueError):
        _SequencePreprocessing(meta)
예제 #4
0
 def create_preproc_module(metadata: Dict[str, Any]) -> torch.nn.Module:
     """Factory: build a sequence-preprocessing torch module from feature metadata."""
     module = _SequencePreprocessing(metadata)
     return module