def test_decoder_remove_blanks(tfm: BatchTextTransformer, blank_input):
    out = tfm.decode_prediction(blank_input.argmax(1))

    assert len(out) == 1
    assert isinstance(out, list)
    assert type(out[0]) is str
    assert out[0] == ""

def test_encode_text(tfm: BatchTextTransformer):
    # Skip this test if the module is exported for inference
    # as it only contains the decoding part.
    if isinstance(tfm, torch.jit.ScriptModule):
        return

    encoded, encoded_lens = tfm.encode(["hello world", "oi"],
                                       return_length=True)
    assert len(encoded) == 2
    assert len(encoded_lens) == 2
    expected = torch.Tensor([29, 8, 5, 12, 12, 15, 0, 23, 15, 18, 12, 4, 30])
    assert (encoded[0] == expected).all()
    assert encoded_lens[0] == 13
    assert encoded_lens[1] == 4

    encoded2 = tfm.encode(["hello world", "oi"], return_length=False)
    assert (encoded == encoded2).all()
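The expected tensor shows that encode wraps each transcript in start/end sentinels (29 and 30 here, with space at index 0) and pads the batch to the longest example, while encoded_lens reports the true lengths. A hypothetical follow-up, assuming encoded_lens comes back as a tensor, showing how those lengths yield a padding mask (not part of the source):

encoded, encoded_lens = tfm.encode(["hello world", "oi"], return_length=True)
# True where a position holds a real token, False where it is padding.
mask = torch.arange(encoded.shape[1]) < encoded_lens.unsqueeze(1)
assert mask[1].sum() == 4  # "oi" fills 4 of the 13 padded positions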

def test_decoder_repeat_same_element(tfm: BatchTextTransformer, blank_input):
    a_idx = tfm.vocab.stoi["a"]
    blank_input[:, a_idx, :10] = 2  # a
    blank_input[:, a_idx, 15:20] = 2  # a
    out = tfm.decode_prediction(blank_input.argmax(1))

    assert len(out) == 1
    assert isinstance(out, list)
    assert type(out[0]) is str
    assert out[0] == "aa"

def test_decoder_simple_sequence(tfm: BatchTextTransformer, blank_input):
    a_idx = tfm.vocab.stoi["a"]
    b_idx = tfm.vocab.stoi["b"]
    blank_input[:, a_idx, :10] = 2  # a
    blank_input[:, b_idx, 15:20] = 2  # b

    out = tfm.decode_prediction(blank_input.argmax(1))

    assert len(out) == 1
    assert isinstance(out, list)
    assert type(out[0]) is str
    assert out[0] == "ab"
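Together, these decoder tests pin down CTC-style greedy decoding: blank frames are dropped, consecutive repeats of a symbol collapse to one character, and a blank between two runs of the same symbol keeps them separate (hence "aa" above, not "a"). A minimal pure-Python sketch of that collapse rule; blank_id and itos are stand-ins, and decode_prediction itself operates on batched tensors:

def greedy_ctc_collapse(frame_ids, blank_id, itos):
    # Collapse consecutive repeats, then drop blanks (standard CTC greedy decode).
    out = []
    prev = None
    for idx in frame_ids:
        if idx != prev and idx != blank_id:
            out.append(itos[idx])
        prev = idx
    return "".join(out)

# 'a a a blank blank a a' -> "aa": the blank separates the two runs.
assert greedy_ctc_collapse([1, 1, 1, 0, 0, 1, 1], blank_id=0, itos={1: "a"}) == "aa"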
Example #5
    def build_text_transform(
        self, initial_vocab_tokens: List[str]
    ) -> BatchTextTransformer:
        """Overwrite this function if you want to change how the text processing happens inside the model.

        Args:
            initial_vocab_tokens : List of tokens to create the vocabulary, special tokens should not be included here.

        Returns:
            The transform that will both `encode` the text and `decode_prediction`.
        """
        vocab = Vocab(initial_vocab_tokens, nemo_compat=False)
        return BatchTextTransformer(vocab=vocab)
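A subclass would override this hook to swap in different text processing; a minimal sketch, assuming a hypothetical SomeModel base class that exposes build_text_transform (the fixed character list is likewise illustrative):

class ModelWithCustomText(SomeModel):  # SomeModel: assumed base class, not in the source
    def build_text_transform(
        self, initial_vocab_tokens: List[str]
    ) -> BatchTextTransformer:
        # Ignore the suggested tokens and force a fixed lowercase character vocabulary.
        vocab = Vocab(list("abcdefghijklmnopqrstuvwxyz '"), nemo_compat=False)
        return BatchTextTransformer(vocab=vocab)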
Example #6
    def build_text_transform(self, initial_vocab_tokens: List[str],
                             nemo_compat_vocab: bool) -> BatchTextTransformer:
        """Overwrite this function if you want to change how the text processing happens inside the model.

        Args:
            initial_vocab_tokens : List of tokens to create the vocabulary, special tokens should not be included here.
            nemo_compat_vocab : Controls if the used vocabulary will be compatible with the original nemo implementation.

        Returns:
            The transform that will both `encode` the text and `decode_prediction`.
        """
        vocab = Vocab(initial_vocab_tokens, nemo_compat=nemo_compat_vocab)
        return BatchTextTransformer(vocab=vocab)
Example #7
@pytest.fixture(params=[True, False])  # assumed parametrization; request.param is used as a bool below
def tfm(simple_vocab, request):
    transform = BatchTextTransformer(vocab=simple_vocab)
    if request.param:
        # TorchScript variant, mirroring the module exported for inference
        return torch.jit.script(transform)
    return transform
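The fixture also depends on a simple_vocab fixture that this excerpt does not show; a plausible companion sketch, with the token ordering assumed from the indices test_encode_text expects (space at 0, h at 8):

import pytest

@pytest.fixture
def simple_vocab():
    # Assumption: space-first lowercase alphabet, matching space=0 and h=8 in
    # test_encode_text; the real fixture may construct its Vocab differently.
    return Vocab(list(" abcdefghijklmnopqrstuvwxyz"), nemo_compat=False)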
Example #8
    def build_text_pipeline(self, initial_vocab_tokens: List[str],
                            nemo_compat_vocab: bool) -> BatchTextTransformer:
        vocab = Vocab(initial_vocab_tokens, nemo_compat=nemo_compat_vocab)
        return BatchTextTransformer(vocab=vocab)