def test_avoid_list_state(raw_phrase_list):
    """Exercise AvoidState transitions over a trie built from the phrases.

    Checks three things:

    * consuming an out-of-vocabulary (OOV) ID at the root state yields an
      equal state with an equal avoid list;
    * after consuming all but the last word of a phrase, that last word is
      reported by ``avoid()``;
    * consuming an OOV ID part-way through a multi-word phrase yields a
      state that differs from the current one.

    :param raw_phrase_list: Phrases to avoid; each is tokenized with
        ``get_tokens``. How the argument is supplied is not visible in this
        block (presumably test parametrization).
    """
    raw_phrase_list = [list(get_tokens(phrase)) for phrase in raw_phrase_list]
    root_trie = AvoidTrie(raw_phrase_list)

    root_state = AvoidState(root_trie)
    # Treated as an ID that occurs in none of the phrases.
    oov_id = 83284

    # Consuming an OOV ID from the root state should return the same state
    assert root_state == root_state.consume(oov_id)

    # The avoid lists should also be the same
    assert root_state.avoid() == root_state.consume(oov_id).avoid()

    for phrase in raw_phrase_list:
        # Walk the phrase up to, but not including, its last word.
        state = root_state
        for word in phrase[:-1]:
            state = state.consume(word)

        # The last word of the phrase should be in the avoid list
        assert phrase[-1] in state.avoid()

        # Trying to advance on an OOV from inside a multi-word constraint
        # should return a new state. Single-word phrases are skipped because
        # for them ``state`` is still the root state.
        if len(phrase) > 1:
            new_state = state.consume(oov_id)
            assert new_state != state
def test_avoid_list_batch(global_raw_phrase_list, raw_phrase_list, batch_size, beam_size, prefix, expected_avoid):
    """Check what AvoidBatch reports after every hypothesis consumes a prefix.

    :param global_raw_phrase_list: Optional phrases for a global AvoidTrie;
        when falsy no global trie is built and ``None`` is passed instead.
    :param raw_phrase_list: Passed unchanged to AvoidBatch as ``avoid_list``.
    :param batch_size: First positional argument to AvoidBatch.
    :param beam_size: Second positional argument to AvoidBatch.
    :param prefix: Text whose tokens (via ``get_tokens`` and ``strids2ids``)
        are consumed one at a time by all ``batch_size * beam_size`` slots.
    :param expected_avoid: Iterable of pairs that ``avoid()`` must produce,
        compared as a set.
    """
    if global_raw_phrase_list:
        global_phrase_ids = [list(strids2ids(get_tokens(phrase))) for phrase in global_raw_phrase_list]
        global_avoid_trie = AvoidTrie(global_phrase_ids)
    else:
        global_avoid_trie = None

    avoid_batch = AvoidBatch(batch_size,
                             beam_size,
                             avoid_list=raw_phrase_list,
                             global_avoid_trie=global_avoid_trie)

    # Feed the same word to every slot at each step of the prefix.
    num_slots = batch_size * beam_size
    for word_id in strids2ids(get_tokens(prefix)):
        step_words = mx.nd.array([word_id] * num_slots)
        avoid_batch.consume(step_words)

    # avoid() is unpacked into parallel sequences; pair them up element-wise
    # and compare without regard to order or duplicates.
    observed = {(x, y) for x, y in zip(*avoid_batch.avoid())}
    assert observed == set(expected_avoid)
def test_avoid_list_trie(raw_phrase_list):
    """Verify AvoidTrie size, phrase walking, and adding a phrase afterwards.

    :param raw_phrase_list: Phrases that are tokenized with ``get_tokens``
        and loaded into the trie. How the argument is supplied is not visible
        in this block (presumably test parametrization).
    """
    phrases = [list(get_tokens(phrase)) for phrase in raw_phrase_list]
    root_trie = AvoidTrie(phrases)

    # The trie's length must match the number of phrases it was built from.
    assert len(root_trie) == len(phrases)

    # Stepping through the first len(phrase) - 1 words must land on a node
    # whose final() collection contains the phrase's last word.
    for phrase in phrases:
        leading_words, last_word = phrase[:-1], phrase[-1]
        node = root_trie
        for word in leading_words:
            node = node.step(word)
        assert last_word in node.final()

    # Before it is added, this ID has no node below the root ...
    oov_id = 8239
    assert root_trie.step(oov_id) is None

    # ... and once a phrase starting with it is added, the node exists and
    # records the phrase's second word as final.
    root_trie.add_phrase([oov_id, 17])
    assert root_trie.step(oov_id) is not None
    assert 17 in root_trie.step(oov_id).final()