예제 #1
0
def test_avoid_list_state(raw_phrase_list):
    """Exercise AvoidState transitions on OOV input and inside avoid phrases.

    :param raw_phrase_list: Raw phrases; each one is tokenized with get_tokens and
        the resulting token lists are used to build the AvoidTrie under test.
    """
    raw_phrase_list = [list(get_tokens(phrase)) for phrase in raw_phrase_list]
    root_trie = AvoidTrie(raw_phrase_list)

    root_state = AvoidState(root_trie)
    oov_id = 83284  # presumably absent from every test phrase

    # Consuming an OOV ID from the root state should return the same state
    assert root_state == root_state.consume(oov_id)

    # The avoid lists should also be the same
    assert root_state.avoid() == root_state.consume(oov_id).avoid()

    for phrase in raw_phrase_list:
        # Walk every token except the last, always starting from the root state
        state = root_state
        for word in phrase[:-1]:
            state = state.consume(word)

        # The last word of the phrase should be in the avoid list
        assert phrase[-1] in state.avoid()

        # Trying to advance on an OOV from inside a multi-word constraint should return a new state
        if len(phrase) > 1:
            new_state = state.consume(oov_id)
            assert new_state != state
예제 #2
0
def test_avoid_list_batch(global_raw_phrase_list, raw_phrase_list, batch_size, beam_size, prefix, expected_avoid):
    """Feed a prefix to an AvoidBatch and compare the reported avoid pairs.

    :param global_raw_phrase_list: Optional raw phrases for a global AvoidTrie; when
        falsy, no global trie is passed to AvoidBatch.
    :param raw_phrase_list: Passed unchanged to AvoidBatch as ``avoid_list``.
    :param batch_size: First positional argument to AvoidBatch.
    :param beam_size: Second positional argument to AvoidBatch.
    :param prefix: Raw string whose token IDs are consumed one step at a time.
    :param expected_avoid: Iterable of pairs expected from ``avoid()`` after the prefix.
    """
    if global_raw_phrase_list:
        global_id_phrases = [list(strids2ids(get_tokens(raw))) for raw in global_raw_phrase_list]
        global_avoid_trie = AvoidTrie(global_id_phrases)
    else:
        global_avoid_trie = None

    avoid_batch = AvoidBatch(batch_size, beam_size, avoid_list=raw_phrase_list, global_avoid_trie=global_avoid_trie)

    # Each step feeds the same word ID once per (batch_size * beam_size) entry.
    num_entries = batch_size * beam_size
    for word_id in strids2ids(get_tokens(prefix)):
        step_ids = [word_id] * num_entries
        avoid_batch.consume(mx.nd.array(step_ids))

    # avoid() is unpacked into parallel sequences that are zipped into pairs;
    # order and duplicates do not matter for the comparison.
    observed = {(x, y) for x, y in zip(*avoid_batch.avoid())}
    assert observed == set(expected_avoid)
예제 #3
0
def test_avoid_list_trie(raw_phrase_list):
    """Check AvoidTrie size reporting, prefix walking, and phrase insertion.

    :param raw_phrase_list: Raw phrases; each one is tokenized with get_tokens
        before the AvoidTrie is built from the token lists.
    """
    token_phrases = [list(get_tokens(raw)) for raw in raw_phrase_list]
    root_trie = AvoidTrie(token_phrases)

    # The trie is expected to report one entry per input phrase.
    assert len(root_trie) == len(token_phrases)

    # Stepping through all tokens but the last must reach a node whose final()
    # contains that last token.
    for tokens in token_phrases:
        node = root_trie
        for token in tokens[:-1]:
            node = node.step(token)
        assert tokens[-1] in node.final()

    # An ID that is not in the trie has nowhere to step to from the root ...
    unseen_id = 8239
    assert root_trie.step(unseen_id) is None

    # ... until a phrase beginning with it is added.
    root_trie.add_phrase([unseen_id, 17])
    assert root_trie.step(unseen_id) is not None
    assert 17 in root_trie.step(unseen_id).final()
예제 #4
0
def test_avoid_list_trie(raw_phrase_list):
    """Check AvoidTrie size reporting, prefix walking, and phrase insertion.

    :param raw_phrase_list: Raw phrases; each one is tokenized with get_tokens
        before the AvoidTrie is built from the token lists.
    """
    token_phrases = [list(get_tokens(raw)) for raw in raw_phrase_list]
    root_trie = AvoidTrie(token_phrases)

    # The trie is expected to report one entry per input phrase.
    assert len(root_trie) == len(token_phrases)

    # Stepping through all tokens but the last must reach a node whose final()
    # contains that last token.
    for tokens in token_phrases:
        node = root_trie
        for token in tokens[:-1]:
            node = node.step(token)
        assert tokens[-1] in node.final()

    # An ID that is not in the trie has nowhere to step to from the root ...
    unseen_id = 8239
    assert root_trie.step(unseen_id) is None

    # ... until a phrase beginning with it is added.
    root_trie.add_phrase([unseen_id, 17])
    assert root_trie.step(unseen_id) is not None
    assert 17 in root_trie.step(unseen_id).final()