def test_avoid_list_batch(global_raw_phrase_list, raw_phrase_list, batch_size, beam_size, prefix, expected_avoid):
    """
    Check the (x, y) pairs an AvoidBatch reports after every hypothesis has consumed ``prefix``.

    When ``global_raw_phrase_list`` is non-empty, its phrases are tokenized, converted to ids and
    stored in an AvoidTrie that the batch receives as ``global_avoid_trie``. ``raw_phrase_list`` is
    passed to AvoidBatch unchanged. The set of pairs zipped from ``avoid()`` must equal
    ``set(expected_avoid)``.
    """
    shared_trie = None
    if global_raw_phrase_list:
        global_phrase_ids = [list(strids2ids(get_tokens(phrase))) for phrase in global_raw_phrase_list]
        shared_trie = AvoidTrie(global_phrase_ids)

    avoid_batch = AvoidBatch(batch_size,
                             beam_size,
                             avoid_list=raw_phrase_list,
                             global_avoid_trie=shared_trie)

    for token_id in strids2ids(get_tokens(prefix)):
        # Every hypothesis in the batch consumes the same token at this step.
        avoid_batch.consume(mx.nd.array([token_id] * (batch_size * beam_size)))

    avoided_pairs = {(x, y) for x, y in zip(*avoid_batch.avoid())}
    assert avoided_pairs == set(expected_avoid)
def test_avoid_list_state(raw_phrase_list):
    """
    Exercise AvoidState transitions over an AvoidTrie built from ``raw_phrase_list``.

    Checks that:
    - consuming an out-of-vocabulary id at the root returns an equal state with the same avoid list;
    - after consuming all but the last word of a phrase, the last word is in the avoid list;
    - consuming an out-of-vocabulary id from inside a multi-word phrase yields a different state.
    """
    raw_phrase_list = [list(get_tokens(phrase)) for phrase in raw_phrase_list]
    root_trie = AvoidTrie(raw_phrase_list)

    root_state = AvoidState(root_trie)
    # An id assumed not to occur in any of the test phrases.
    oov_id = 83284

    # Consuming an OOV ID from the root state should return the same state
    assert root_state == root_state.consume(oov_id)

    # The avoid lists should also be the same
    assert root_state.avoid() == root_state.consume(oov_id).avoid()

    for phrase in raw_phrase_list:
        state = root_state
        for word in phrase[:-1]:
            state = state.consume(word)

        # The last word of the phrase should be in the avoid list
        assert phrase[-1] in state.avoid()

        # Trying to advance on an OOV from inside a multi-word constraint should return a new state
        if len(phrase) > 1:
            new_state = state.consume(oov_id)
            assert new_state != state
def build_vocab(data: Iterable[str], num_words: int = 50000, min_count: int = 1) -> Dict[str, int]:
    """
    Creates a vocabulary mapping from words to ids. Increasing integer ids are assigned by descending word
    frequency. Words with the same frequency are ordered by reverse lexical order. The only exception to
    this are the special symbols in C.VOCAB_SYMBOLS (such as the padding symbol PAD). They always receive
    the first ids, in the order in which they appear in C.VOCAB_SYMBOLS.

    :param data: Sequence of sentences containing whitespace delimited tokens.
    :param num_words: Maximum number of words taken from the data. The special symbols are added on top
        of this limit.
    :param min_count: Minimum occurrences of words to be included in the vocabulary.
    :return: Word-to-id mapping.
    """
    vocab_symbols_set = set(C.VOCAB_SYMBOLS)
    # Special symbols are not counted; they are prepended explicitly when ids are assigned below.
    raw_vocab = Counter(token for line in data for token in get_tokens(line)
                        if token not in vocab_symbols_set)
    # Pass arguments to the logger instead of pre-formatting, consistent with the calls below.
    logger.info("Initial vocabulary: %d types", len(raw_vocab))

    # Sorting (count, word) tuples with reverse=True orders words by descending count. Ties are broken
    # by reverse lexical order of the word. The particular tie-break does not matter. It only has to be
    # deterministic, so that the same data always yields the same vocabulary.
    pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True)
    logger.info("Pruned vocabulary: %d types (min frequency %d)", len(pruned_vocab), min_count)

    # Keep at most num_words of the most frequent words.
    vocab = islice((w for c, w in pruned_vocab), num_words)

    word_to_id = {word: idx for idx, word in enumerate(chain(C.VOCAB_SYMBOLS, vocab))}
    logger.info("Final vocabulary: %d types (min frequency %d, top %d types)",
                len(word_to_id), min_count, num_words)

    # Important: the pad symbol must map to C.PAD_ID (expected to be index 0)
    assert word_to_id[C.PAD_SYMBOL] == C.PAD_ID
    return word_to_id
def test_avoid_list_trie(raw_phrase_list):
    """
    Exercise AvoidTrie construction, step(), final() and add_phrase().

    Checks that:
    - the trie's length equals the number of phrases;
    - walking all but the last word of each phrase reaches a node whose final() contains the last word;
    - stepping on an unknown id returns None until a phrase starting with that id is added, after which
        the resulting child reports the phrase's second id in final().
    """
    phrases = [list(get_tokens(phrase)) for phrase in raw_phrase_list]
    root_trie = AvoidTrie(phrases)

    # Make sure the trie reports the right size
    assert len(root_trie) == len(phrases)

    for phrase in phrases:
        node = root_trie
        for word in phrase[:-1]:
            node = node.step(word)
        # The last word must be marked final after walking the preceding words.
        assert phrase[-1] in node.final()

    unknown_id = 8239
    assert root_trie.step(unknown_id) is None

    # Adding a two-word phrase makes the unknown id a valid first step.
    root_trie.add_phrase([unknown_id, 17])
    assert root_trie.step(unknown_id) is not None
    assert 17 in root_trie.step(unknown_id).final()