示例#1
0
def test_sequence_reader(sequences, use_vocab, add_bos, add_eos):
    """Write ``sequences`` to a temporary file and check what ``SequenceReader`` yields.

    With ``use_vocab``, a vocabulary is built from ``sequences``. The reader is
    then expected to produce ``data_io.tokens2ids`` output, with the BOS id
    prepended and/or the EOS id appended to each non-empty id list.

    Without a vocabulary, the reader is expected to produce
    ``data_io.strids2ids`` output. Requesting BOS in that mode must raise
    ``SockeyeError``.

    In both modes, a falsy input line is expected to be read back as ``None``.
    Arguments are presumably supplied by pytest parametrization (not visible here).
    """
    with TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, 'input')
        with open(input_path, 'w') as handle:
            for line in sequences:
                print(line, file=handle)

        if use_vocab:
            vocabulary = vocab.build_vocab(sequences)
        else:
            vocabulary = None

        reader = data_io.SequenceReader(input_path, vocabulary=vocabulary,
                                        add_bos=add_bos, add_eos=add_eos)

        actual = [item for item in reader]
        assert len(actual) == len(sequences)

        if vocabulary is not None:
            expected = [data_io.tokens2ids(get_tokens(line), vocabulary) if line else None
                        for line in sequences]
            # BOS/EOS are applied in two passes so that empty id lists become None.
            if add_bos:
                expected = [[vocabulary[C.BOS_SYMBOL]] + ids if ids else None
                            for ids in expected]
            if add_eos:
                expected = [ids + [vocabulary[C.EOS_SYMBOL]] if ids else None
                            for ids in expected]
            assert actual == expected
        else:
            # Adding special symbols is only meaningful when a vocabulary exists.
            with pytest.raises(SockeyeError) as exc_info:
                data_io.SequenceReader(input_path, vocabulary=vocabulary, add_bos=True)
            assert str(exc_info.value) == "Adding a BOS or EOS symbol requires a vocabulary"

            expected = [data_io.strids2ids(get_tokens(line)) if line else None
                        for line in sequences]
            assert actual == expected
示例#2
0
def test_sequence_reader(sequences, use_vocab, add_bos):
    """Round-trip ``sequences`` through a temporary file and ``SequenceReader``.

    Checks that:
    - the reader reports ``is_done()`` after full iteration;
    - the reader's ``count`` matches the number of items yielded;
    - the yielded items match ``tokens2ids`` output (optionally BOS-prefixed)
      when a vocabulary is used, or ``strids2ids`` output otherwise;
    - requesting BOS without a vocabulary raises ``SockeyeError``;
    - a second concurrent ``iter()`` on the reader raises ``SockeyeError``.

    Arguments are presumably supplied by pytest parametrization (not visible here).
    """
    with TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, 'input')
        with open(input_path, 'w') as handle:
            handle.writelines(line + "\n" for line in sequences)

        if use_vocab:
            vocabulary = vocab.build_vocab(sequences)
        else:
            vocabulary = None

        reader = data_io.SequenceReader(input_path, vocab=vocabulary, add_bos=add_bos)

        actual = [item for item in reader]
        assert reader.is_done()
        assert len(actual) == reader.count

        if vocabulary is not None:
            expected = [data_io.tokens2ids(get_tokens(line), vocabulary) for line in sequences]
            if add_bos:
                expected = [[vocabulary[C.BOS_SYMBOL]] + ids for ids in expected]
        else:
            # BOS insertion is rejected when there is no vocabulary to look it up in.
            with pytest.raises(SockeyeError) as exc_info:
                data_io.SequenceReader(input_path, vocab=vocabulary, add_bos=True)
            assert str(exc_info.value) == "Adding a BOS symbol requires a vocabulary"
            expected = [data_io.strids2ids(get_tokens(line)) for line in sequences]
        assert actual == expected

        # Keep a reference to a first iterator, then request a second one,
        # which is expected to be refused.
        first_iterator = iter(reader)
        with pytest.raises(SockeyeError) as exc_info:
            iter(reader)
        assert str(exc_info.value) == "Can not iterate multiple times simultaneously."
示例#3
0
def test_sequence_reader(sequences, use_vocab, add_bos, add_eos):
    """Check ``SequenceReader`` output for a file containing ``sequences``.

    Each falsy input line is expected back as ``None``. With a vocabulary built
    from ``sequences``, lines map through ``data_io.tokens2ids``, optionally
    wrapped with BOS/EOS ids. Without one, they map through
    ``data_io.strids2ids``, and asking for BOS raises ``SockeyeError``.

    Arguments are presumably supplied by pytest parametrization (not visible here).
    """
    with TemporaryDirectory() as scratch:
        source_file = os.path.join(scratch, 'input')
        with open(source_file, 'w') as out:
            for entry in sequences:
                print(entry, file=out)

        vocabulary = None
        if use_vocab:
            vocabulary = vocab.build_vocab(sequences)

        reader = data_io.SequenceReader(source_file, vocabulary=vocabulary,
                                        add_bos=add_bos, add_eos=add_eos)

        got = [seq for seq in reader]
        assert len(got) == len(sequences)

        if vocabulary is not None:
            want = [data_io.tokens2ids(get_tokens(entry), vocabulary) if entry else None
                    for entry in sequences]
            # Two separate passes keep the original "empty ids -> None" semantics.
            if add_bos:
                want = [[vocabulary[C.BOS_SYMBOL]] + seq if seq else None for seq in want]
            if add_eos:
                want = [seq + [vocabulary[C.EOS_SYMBOL]] if seq else None for seq in want]
            assert got == want
        else:
            with pytest.raises(SockeyeError) as exc_info:
                data_io.SequenceReader(source_file, vocabulary=vocabulary, add_bos=True)
            assert str(exc_info.value) == "Adding a BOS or EOS symbol requires a vocabulary"

            want = [data_io.strids2ids(get_tokens(entry)) if entry else None
                    for entry in sequences]
            assert got == want
示例#4
0
def test_padded_build_vocab(num_types, pad_to_multiple_of,
                            expected_vocab_size):
    """Check the vocabulary size produced with ``pad_to_multiple_of``.

    The input is one line of ``num_types`` distinct tokens (``word0``,
    ``word1``, ...). Each token occurs once, and ``min_count`` is 1 and the
    size argument is ``None``.
    """
    tokens = ['word{}'.format(i) for i in range(num_types)]
    built = build_vocab([" ".join(tokens)],
                        None,
                        1,
                        pad_to_multiple_of=pad_to_multiple_of)
    assert len(built) == expected_vocab_size
示例#5
0
def test_constants_in_vocab(data, size, min_count, constants):
    """Every item of ``constants`` must be a member of the vocabulary built from ``data``."""
    built = build_vocab(data, size, min_count)
    missing = [const for const in constants if const not in built]
    assert not missing
示例#6
0
def test_build_vocab(data, size, min_count, expected):
    """The vocabulary built positionally from ``data``, ``size`` and ``min_count`` must equal ``expected``."""
    assert build_vocab(data, size, min_count) == expected
示例#7
0
def test_build_vocab(data, size, min_count, expected):
    """``build_vocab`` with ``num_words=size`` must produce exactly ``expected``."""
    built = build_vocab(min_count=min_count, num_words=size, data=data)
    assert built == expected
示例#8
0
def test_constants_in_vocab(data, size, min_count, constants):
    """Assert that no entry of ``constants`` is absent from the built vocabulary."""
    built = build_vocab(data, size, min_count)
    absent = [item for item in constants if item not in built]
    assert not absent
示例#9
0
def test_padded_build_vocab(num_types, pad_to_multiple_of, expected_vocab_size):
    """Vocabulary length under ``pad_to_multiple_of`` must equal ``expected_vocab_size``.

    Input: a single line of ``num_types`` unique tokens, with the size argument
    ``None`` and ``min_count`` 1.
    """
    line = " ".join('word' + str(idx) for idx in range(num_types))
    built = build_vocab([line], None, 1, pad_to_multiple_of=pad_to_multiple_of)
    assert len(built) == expected_vocab_size
示例#10
0
def test_build_vocab(data, size, min_count, expected):
    """Building a vocabulary from ``data`` (``num_words=size``) must yield ``expected``."""
    assert build_vocab(data=data, min_count=min_count, num_words=size) == expected