Exemplo n.º 1
0
def test_TextEncoder_encodeOrdinal_outputShapeIsCorrect():
    """Ordinal encoding yields one row per token in a single column.

    The encoded sentence contains every word plus the ``_START_`` and
    ``_STOP_`` markers the coder wraps around it.
    """
    from utils import load_data, TextCoder

    sentence = "this is a european session"
    # words + _START_ + _STOP_  ->  5 + 2 == 7 rows
    expected_rows = len(sentence.split()) + 2

    corpus = load_data(["test_data/10k.en"], ["en"], quiet=True)
    en_coder = TextCoder(corpus["en"], vocabulary_size=2000)
    batch = en_coder.encode([sentence], one_hot=False)
    assert batch[0].shape == (expected_rows, 1)
Exemplo n.º 2
0
def test_TextEncoder_encodeDecode_recoversOriginal():
    """Encoding then decoding returns the input wrapped in start/stop markers."""
    from utils import load_data, TextCoder

    corpus = load_data(["test_data/10k.en"], ["en"], quiet=True)
    en_coder = TextCoder(corpus["en"], vocabulary_size=2000)

    round_tripped = en_coder.decode(
        en_coder.encode(["this is it"], one_hot=False),
        one_hot=False,
    )
    assert round_tripped[0] == "_START_ this is it _STOP_"
Exemplo n.º 3
0
def test_TextEncoder_encode_stopAddedIfMissing():
    """A sentence without ``_STOP_`` gets one appended as its last token.

    Checked for both the ordinal encoding (token id in column 0) and the
    one-hot encoding (token id recovered via ``argmax`` over the row).
    """
    from utils import load_data, TextCoder

    corpus = load_data(["test_data/10k.en"], ["en"], quiet=True)
    en_coder = TextCoder(corpus["en"], vocabulary_size=2000)
    sentence = "this is a european session"

    ordinal = en_coder.encode([sentence], one_hot=False)[0]
    stop_id = en_coder.word2int["_STOP_"]
    assert ordinal[-1, 0] == stop_id

    one_hot = en_coder.encode([sentence], one_hot=True)[0]
    assert one_hot[-1, :].argmax() == stop_id
Exemplo n.º 4
0
def test_TextEncoder_encode_startNotAddedIfPresent():
    """A sentence that already begins with ``_START_`` is not given a second one.

    The first token must be ``_START_`` and the second must not be, for both
    the ordinal and the one-hot encodings.
    """
    from utils import load_data, TextCoder

    corpus = load_data(["test_data/10k.en"], ["en"], quiet=True)
    en_coder = TextCoder(corpus["en"], vocabulary_size=2000)
    sentence = "_START_ this is a european session"

    ordinal = en_coder.encode([sentence], one_hot=False)[0]
    start_id = en_coder.word2int["_START_"]
    assert ordinal[0, 0] == start_id
    assert ordinal[1, 0] != start_id

    one_hot = en_coder.encode([sentence], one_hot=True)[0]
    assert one_hot[0, :].argmax() == start_id
    assert one_hot[1, :].argmax() != start_id
Exemplo n.º 5
0
def drop_to_pudb(paths=None, languages=None):
    """Load a corpus, build a coder per language, then break into pudb.

    Interactive debugging aid rather than a test. It is not named ``test_*``,
    so default pytest collection does not pick it up. ``data``, ``sv_coder``
    and ``en_coder`` are left as locals so they can be inspected from the
    debugger prompt.

    Args:
        paths: Corpus files passed to ``load_data``. Defaults to the small
            test corpora ``["test_data/10k.sv", "test_data/10k.en"]``. Pass
            e.g. ``["data/europarl-v7.sv-en.sv", "data/europarl-v7.sv-en.en"]``
            to debug against the full Europarl data instead.
        languages: Language keys passed to ``load_data`` alongside ``paths``.
            Defaults to ``["sv", "en"]``. The coders below read ``data["sv"]``
            and ``data["en"]``, so both keys must be present.
    """
    from pudb import set_trace
    from utils import load_data, TextCoder

    # None sentinels instead of list defaults, so no mutable default is shared.
    if paths is None:
        paths = ["test_data/10k.sv", "test_data/10k.en"]
    if languages is None:
        languages = ["sv", "en"]

    data = load_data(paths, languages)
    # Deliberately "unused": these exist to be inspected at the breakpoint.
    sv_coder = TextCoder(data["sv"])
    en_coder = TextCoder(data["en"])
    set_trace()
Exemplo n.º 6
0
def test_TextEncoder_encodeDecode_oneHotAndOrdinalSame():
    """One-hot and ordinal round trips decode the corpus to identical text.

    Every sentence of the test corpus is encoded both ways. Each result is
    decoded with the matching mode, and the two decoded corpora must agree
    sentence by sentence.
    """
    from utils import load_data, TextCoder

    data = load_data(["test_data/10k.en"], ["en"], quiet=True)
    coder = TextCoder(data["en"], vocabulary_size=2000)
    sentences = data["en"].values

    onehot_encoded = coder.encode(sentences, one_hot=True)
    onehot_decoded = coder.decode(onehot_encoded, one_hot=True)
    ordinal_encoded = coder.encode(sentences, one_hot=False)
    ordinal_decoded = coder.decode(ordinal_encoded, one_hot=False)

    # zip() stops at the shorter input. Without these checks, a decode that
    # dropped sentences (or returned nothing at all) would pass vacuously.
    assert len(onehot_decoded) == len(ordinal_decoded)
    assert len(ordinal_decoded) > 0

    for index, (onehot, ordinal) in enumerate(zip(onehot_decoded, ordinal_decoded)):
        assert onehot == ordinal, f"sentence {index}: {onehot!r} != {ordinal!r}"