Esempio n. 1
0
def test_calculate_length_statistics(sources, target, expected_num_sents,
                                     expected_mean, expected_std):
    """Check the result of ``data_io.calculate_length_statistics``.

    The function under test is called with ``sources``, ``target`` and two
    literal ``5`` arguments (presumably length limits -- confirm against
    ``data_io``). The returned object's ``num_sents`` must equal
    ``expected_num_sents``, and its ``length_ratio_mean`` and
    ``length_ratio_std`` must be close (per ``np.isclose``) to
    ``expected_mean`` and ``expected_std`` respectively.
    """
    stats = data_io.calculate_length_statistics(sources, target, 5, 5)

    # Fixture sanity check: the first element of ``sources`` and ``target``
    # must have the same length.
    first_source_len = len(sources[0])
    assert first_source_len == len(target)

    assert stats.num_sents == expected_num_sents

    mean_matches = np.isclose(stats.length_ratio_mean, expected_mean)
    assert mean_matches

    std_matches = np.isclose(stats.length_ratio_std, expected_std)
    assert std_matches
Esempio n. 2
0
def test_calculate_length_statistics(sources, targets, expected_num_sents, expected_mean, expected_std):
    """Check the result of ``data_io.calculate_length_statistics``.

    The test is skipped when ``mxnet`` cannot be imported. Otherwise the
    function under test is called with ``sources``, ``targets`` and two
    literal ``5`` arguments (presumably length limits -- confirm against
    ``data_io``), and the returned ``num_sents``, ``length_ratio_mean`` and
    ``length_ratio_std`` are compared with the expected values.
    """
    # Skip guard must run before the imports below.
    pytest.importorskip('mxnet')
    from sockeye import data_io
    from mxnet import np

    stats = data_io.calculate_length_statistics(sources, targets, 5, 5)

    # Fixture sanity check: the first elements of ``sources`` and ``targets``
    # must have the same length.
    first_source_len = len(sources[0])
    assert first_source_len == len(targets[0])

    assert stats.num_sents == expected_num_sents

    mean_matches = np.isclose(stats.length_ratio_mean, expected_mean)
    assert mean_matches

    std_matches = np.isclose(stats.length_ratio_std, expected_std)
    assert std_matches
Esempio n. 3
0
def test_non_parallel_calculate_length_statistics(sources, target):
    """Expect ``SockeyeError`` from ``data_io.calculate_length_statistics``.

    The call receives ``sources``, ``target`` and two literal ``5``
    arguments. Judging by the test name, the fixtures are presumably
    non-parallel data -- the parametrization is not visible here.
    """
    expected_failure = pytest.raises(SockeyeError)
    with expected_failure:
        data_io.calculate_length_statistics(sources, target, 5, 5)
Esempio n. 4
0
def test_non_parallel_calculate_length_statistics(sources, target):
    """Verify that computing length statistics raises ``SockeyeError``.

    ``data_io.calculate_length_statistics`` is invoked with ``sources``,
    ``target`` and two literal ``5`` arguments inside a ``pytest.raises``
    context. The inputs are presumably mismatched (non-parallel) data, as the
    test name suggests -- confirm against the parametrization.
    """
    raises_sockeye_error = pytest.raises(SockeyeError)
    with raises_sockeye_error:
        data_io.calculate_length_statistics(sources, target, 5, 5)
Esempio n. 5
0
def test_calculate_length_statistics(sources, target, expected_num_sents, expected_mean, expected_std):
    """Compare computed length statistics with the expected values.

    ``data_io.calculate_length_statistics`` is called with ``sources``,
    ``target`` and two literal ``5`` arguments (presumably length limits --
    confirm against ``data_io``). The test then checks ``num_sents`` for
    equality and ``length_ratio_mean`` / ``length_ratio_std`` for closeness
    via ``np.isclose``.
    """
    result = data_io.calculate_length_statistics(sources, target, 5, 5)

    # Fixture sanity check: the first element of ``sources`` and ``target``
    # must have the same length.
    source_len = len(sources[0])
    target_len = len(target)
    assert source_len == target_len

    observed_num_sents = result.num_sents
    assert observed_num_sents == expected_num_sents

    observed_mean = result.length_ratio_mean
    assert np.isclose(observed_mean, expected_mean)

    observed_std = result.length_ratio_std
    assert np.isclose(observed_std, expected_std)
Esempio n. 6
0
def test_non_parallel_calculate_length_statistics(sources, targets):
    """Expect ``SockeyeError`` from ``data_io.calculate_length_statistics``.

    The test is skipped when ``mxnet`` cannot be imported. Otherwise the
    function under test is called with ``sources``, ``targets`` and two
    literal ``5`` arguments inside a ``pytest.raises`` context. The fixtures
    are presumably non-parallel data, as the test name suggests -- the
    parametrization is not visible here.
    """
    # Skip guard must run before the import below.
    pytest.importorskip('mxnet')
    from sockeye import data_io

    expected_failure = pytest.raises(SockeyeError)
    with expected_failure:
        data_io.calculate_length_statistics(sources, targets, 5, 5)