コード例 #1
0
ファイル: test_data_io.py プロジェクト: ye-kyaw-thu/sockeye
def test_calculate_length_statistics(sources, target, expected_num_sents,
                                     expected_mean, expected_std):
    """Check length statistics computed over parallel source/target data.

    ``sources[0]`` must hold as many items as ``target``. The statistics
    returned by ``data_io.calculate_length_statistics`` are compared against
    the expected sentence count and length-ratio mean/std.
    """
    # Check the fixture is parallel *before* computing statistics: for
    # non-parallel input calculate_length_statistics raises SockeyeError,
    # which would otherwise mask this clearer precondition failure.
    assert len(sources[0]) == len(target)
    # NOTE(review): the two 5s are presumably the max source/target sequence
    # lengths — confirm against data_io.calculate_length_statistics.
    length_statistics = data_io.calculate_length_statistics(
        sources, target, 5, 5)
    assert length_statistics.num_sents == expected_num_sents
    # Float statistics are compared with a tolerance, not exact equality.
    assert np.isclose(length_statistics.length_ratio_mean, expected_mean)
    assert np.isclose(length_statistics.length_ratio_std, expected_std)
コード例 #2
0
def test_calculate_length_statistics(sources, targets, expected_num_sents, expected_mean, expected_std):
    """Check length statistics computed over parallel source/target data.

    Skipped when mxnet is not installed. ``sources[0]`` must hold as many
    items as ``targets[0]``. The computed statistics are compared against the
    expected sentence count and length-ratio mean/std.
    """
    pytest.importorskip('mxnet')
    # Imported only after the skip check, presumably so that this module can
    # be collected without mxnet installed.
    from sockeye import data_io
    from mxnet import np
    # Check the fixture is parallel *before* computing statistics: for
    # non-parallel input calculate_length_statistics raises SockeyeError,
    # which would otherwise mask this clearer precondition failure.
    assert len(sources[0]) == len(targets[0])
    # NOTE(review): the two 5s are presumably the max source/target sequence
    # lengths — confirm against data_io.calculate_length_statistics.
    length_statistics = data_io.calculate_length_statistics(sources, targets, 5, 5)
    assert length_statistics.num_sents == expected_num_sents
    # Float statistics are compared with a tolerance, not exact equality.
    assert np.isclose(length_statistics.length_ratio_mean, expected_mean)
    assert np.isclose(length_statistics.length_ratio_std, expected_std)
コード例 #3
0
def test_non_parallel_calculate_length_statistics(sources, target):
    """Non-parallel source/target data must make the statistics call fail.

    Passes only if ``data_io.calculate_length_statistics`` raises
    ``SockeyeError`` for the given fixtures.
    """
    # Callable form of pytest.raises: calls the function with the positional
    # arguments that follow and asserts that SockeyeError is raised.
    pytest.raises(SockeyeError, data_io.calculate_length_statistics,
                  sources, target, 5, 5)
コード例 #4
0
ファイル: test_data_io.py プロジェクト: lagka/sockeye
def test_non_parallel_calculate_length_statistics(sources, target):
    """Expect SockeyeError when source and target data are not parallel."""
    # NOTE(review): presumably the max source/target sequence lengths —
    # confirm against data_io.calculate_length_statistics.
    source_limit, target_limit = 5, 5
    expect_sockeye_error = pytest.raises(SockeyeError)
    with expect_sockeye_error:
        data_io.calculate_length_statistics(sources, target,
                                            source_limit, target_limit)
コード例 #5
0
ファイル: test_data_io.py プロジェクト: lagka/sockeye
def test_calculate_length_statistics(sources, target, expected_num_sents, expected_mean, expected_std):
    """Check length statistics computed over parallel source/target data.

    ``sources[0]`` must hold as many items as ``target``. The statistics
    returned by ``data_io.calculate_length_statistics`` are compared against
    the expected sentence count and length-ratio mean/std.
    """
    # Check the fixture is parallel *before* computing statistics: for
    # non-parallel input calculate_length_statistics raises SockeyeError,
    # which would otherwise mask this clearer precondition failure.
    assert len(sources[0]) == len(target)
    # NOTE(review): the two 5s are presumably the max source/target sequence
    # lengths — confirm against data_io.calculate_length_statistics.
    length_statistics = data_io.calculate_length_statistics(sources, target, 5, 5)
    assert length_statistics.num_sents == expected_num_sents
    # Float statistics are compared with a tolerance, not exact equality.
    assert np.isclose(length_statistics.length_ratio_mean, expected_mean)
    assert np.isclose(length_statistics.length_ratio_std, expected_std)
コード例 #6
0
def test_non_parallel_calculate_length_statistics(sources, targets):
    """Expect SockeyeError for non-parallel sources/targets (requires mxnet)."""
    # Skip when mxnet is unavailable. data_io is imported only afterwards,
    # presumably so this module can be collected without mxnet installed.
    pytest.importorskip('mxnet')
    from sockeye import data_io
    # Callable form of pytest.raises: invokes the function with these
    # positional arguments and asserts that SockeyeError is raised.
    pytest.raises(SockeyeError, data_io.calculate_length_statistics,
                  sources, targets, 5, 5)