コード例 #1
0
def test_beam_width_changes():
    '''
    Check that the beam width k changes the decoded transcript.

    The same probability matrix and toy language model are decoded twice:
    with k=25 the expected result is ' ', with k=1 it is 'A '.
    '''
    def penalize_a(prefix):
        # Toy LM: scores the exact prefix 'A' as 0.5 and everything else as 1.
        return 0.5 if prefix == 'A' else 1

    labels = ['_', 'A', ' ']
    # Four time steps (rows) over the columns '_', 'A', ' '; the final step
    # puts all of its mass on ' '.
    probs = np.array([
        [0.8, 0.2, 0],
        [0.7, 0.3, 0],
        [0.6, 0.4, 0],
        [0.0, 0.0, 1],
    ])
    shared = dict(lm=penalize_a, return_weights=False, alpha=1, beta=0)

    wide = prefix_beam_search(probs, labels, k=25, **shared)
    narrow = prefix_beam_search(probs, labels, k=1, **shared)

    assert wide == ' '
    assert narrow == 'A '
コード例 #2
0
def test_sanity():
    '''
    Decode a hand-built 10-step probability matrix and expect 'ASR'.
    '''
    frames = np.zeros((10, len(english_labels)))
    # Steps 0, 1 and 2 each peak on one column of english_labels
    # (presumably the labels 'A', 'S' and 'R' - confirm against english_labels).
    for step, column in enumerate((2, 20, 19)):
        frames[step, column] = 0.5
    # All remaining steps peak on column 0 (presumably the blank label).
    frames[3:, 0] = 0.5
    decoded = prefix_beam_search(frames, english_labels)
    assert decoded == 'ASR'
コード例 #3
0
def test_beam_is_not_greedy():
    '''
    Example from https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-51889a3d85a7
    Shows that beam search can find a path that greedy decoding can not.

    With '_' as the blank (blank_index=0), the transcript 'A' is reachable
    through the alignments A-A, A-_ and _-A, whose probabilities sum to
    0.2*0.4 + 0.2*0.6 + 0.8*0.4 = 0.52. The most likely label at each single
    step is '_' (0.8, then 0.6), so greedy decoding is expected to yield ''.
    '''
    labels = ['_', 'A', 'B']
    # Two time steps (rows) over the three labels (columns). The rows must be
    # nested in one outer list: np.array's second positional parameter is
    # ``dtype``, so passing the rows as separate arguments tried to use the
    # second row as a dtype and failed instead of building a 2x3 matrix.
    samples = np.array([[0.8, 0.2, 0], [0.6, 0.4, 0]])
    res = prefix_beam_search(samples,
                             labels,
                             blank_index=0,
                             return_weights=True)
    assert res == ('A', 0.52)
    greedy_res = greedy_decode(samples, labels)
    assert greedy_res == ''
コード例 #4
0
def test_beam_is_not_greedy():
    '''
    Example from https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-51889a3d85a7
    Beam search recovers the transcript 'A' (total probability
    0.2*0.4 + 0.2*0.6 + 0.8*0.4 = 0.52), while greedy decoding of the same
    matrix is expected to produce an empty transcript.
    '''
    labels = ['_', 'A', 'B', ' ']
    # Two time steps (rows) over the four labels (columns); 'B' and ' ' never occur.
    probs = np.array([
        [0.8, 0.2, 0, 0],
        [0.6, 0.4, 0, 0],
    ])

    beam_out = prefix_beam_search(probs, labels, blank_index=0, return_weights=True)
    assert beam_out == ('A', 0.52)

    decoder = GreedyDecoder(labels, blank_index=0)
    # The greedy decoder takes a torch tensor with a leading batch axis of size 1.
    batch = torch.FloatTensor(probs).unsqueeze(0)
    assert decoder.decode(batch, sizes=None) == ['']
コード例 #5
0
def test_inconsistent_sizes():
    '''
    prefix_beam_search must raise AssertionError when the probability matrix
    has a different number of columns than there are labels.
    '''
    # One column fewer than english_labels has entries.
    sample = np.zeros((10, len(english_labels) - 1))
    # pytest.raises already fails the test when no AssertionError is raised,
    # so no further check on the captured exception is needed.
    with pytest.raises(AssertionError):
        prefix_beam_search(sample, english_labels)