def test_avoid_list_batch(global_raw_phrase_list, raw_phrase_list, batch_size, beam_size, prefix, expected_avoid):
    """Feed a prefix to an ``AvoidBatch`` and compare its avoid pairs with the expected ones.

    A global ``AvoidTrie`` is built from ``global_raw_phrase_list`` only when
    that list is non-empty; otherwise ``None`` is passed. Every token id of
    ``prefix`` is consumed once, replicated ``batch_size * beam_size`` times.
    The sequences returned by ``avoid()`` are zipped into pairs and compared
    with ``expected_avoid`` as sets, so order and duplicates are ignored.
    """
    if global_raw_phrase_list:
        global_phrase_ids = [list(strids2ids(get_tokens(phrase))) for phrase in global_raw_phrase_list]
        global_avoid_trie = AvoidTrie(global_phrase_ids)
    else:
        global_avoid_trie = None

    avoid_batch = AvoidBatch(batch_size, beam_size,
                             avoid_list=raw_phrase_list,
                             global_avoid_trie=global_avoid_trie)

    # Each consumed step presents the same word id to every row of the batch.
    for word_id in strids2ids(get_tokens(prefix)):
        step_ids = [word_id] * (batch_size * beam_size)
        avoid_batch.consume(mx.nd.array(step_ids))

    observed = {(first, second) for first, second in zip(*avoid_batch.avoid())}
    assert observed == set(expected_avoid)
def test_sequence_reader(sequences, use_vocab, add_bos, add_eos):
    """Write ``sequences`` to a temporary file and verify what ``SequenceReader`` yields.

    Without a vocabulary the reader is expected to yield ``strids2ids`` output
    and to reject ``add_bos=True`` with a ``SockeyeError``. With a vocabulary it
    is expected to yield ``tokens2ids`` output, optionally wrapped in BOS/EOS
    ids. Falsy entries are expected to come back as ``None``.
    """
    with TemporaryDirectory() as work_dir:
        path = os.path.join(work_dir, 'input')
        with open(path, 'w') as handle:
            handle.writelines("{}\n".format(line) for line in sequences)

        vocabulary = None
        if use_vocab:
            vocabulary = vocab.build_vocab(sequences)

        reader = data_io.SequenceReader(path,
                                        vocabulary=vocabulary,
                                        add_bos=add_bos,
                                        add_eos=add_eos)
        read_sequences = list(reader)
        assert len(read_sequences) == len(sequences)

        if vocabulary is None:
            # BOS/EOS ids come from the vocabulary, so asking for them without one must fail.
            with pytest.raises(SockeyeError) as e:
                data_io.SequenceReader(path, vocabulary=vocabulary, add_bos=True)
            assert str(e.value) == "Adding a BOS or EOS symbol requires a vocabulary"

            expected_sequences = [data_io.strids2ids(get_tokens(line)) if line else None
                                  for line in sequences]
        else:
            expected_sequences = [data_io.tokens2ids(get_tokens(line), vocabulary) if line else None
                                  for line in sequences]
            if add_bos:
                expected_sequences = [[vocabulary[C.BOS_SYMBOL]] + ids if ids else None
                                      for ids in expected_sequences]
            if add_eos:
                expected_sequences = [ids + [vocabulary[C.EOS_SYMBOL]] if ids else None
                                      for ids in expected_sequences]

        assert read_sequences == expected_sequences
def test_sequence_reader(sequences, use_vocab, add_bos):
    """Verify ``SequenceReader`` output, bookkeeping and its single-iteration guard.

    ``sequences`` are written one per line to a temporary file and read back.
    The test checks ``is_done()`` and ``count`` after a full pass, compares the
    yielded ids with ``strids2ids`` (no vocabulary) or ``tokens2ids`` (with a
    vocabulary, optionally BOS-prefixed), and finally checks that starting a
    second iteration while one is open raises ``SockeyeError``.
    """
    with TemporaryDirectory() as work_dir:
        path = os.path.join(work_dir, 'input')
        with open(path, 'w') as handle:
            handle.writelines(line + "\n" for line in sequences)

        if use_vocab:
            vocabulary = vocab.build_vocab(sequences)
        else:
            vocabulary = None

        reader = data_io.SequenceReader(path, vocab=vocabulary, add_bos=add_bos)

        read_sequences = list(reader)
        assert reader.is_done()
        assert len(read_sequences) == reader.count

        if vocabulary is None:
            # Requesting BOS without a vocabulary must be rejected.
            with pytest.raises(SockeyeError) as e:
                data_io.SequenceReader(path, vocab=vocabulary, add_bos=True)
            assert str(e.value) == "Adding a BOS symbol requires a vocabulary"

            expected_sequences = [data_io.strids2ids(get_tokens(line)) for line in sequences]
        else:
            expected_sequences = [data_io.tokens2ids(get_tokens(line), vocabulary) for line in sequences]
            if add_bos:
                expected_sequences = [[vocabulary[C.BOS_SYMBOL]] + ids for ids in expected_sequences]

        assert read_sequences == expected_sequences

        # Open one iteration, then check that opening a second one is refused.
        open_iteration = iter(reader)
        with pytest.raises(SockeyeError) as e:
            iter(reader)
        assert str(e.value) == "Can not iterate multiple times simultaneously."
def test_sequence_reader(sequences, use_vocab, add_bos, add_eos):
    """Check that ``SequenceReader`` reproduces ``sequences`` as id lists.

    The sequences are written one per line to a file in a temporary directory
    and read back through ``data_io.SequenceReader``. Expected values are
    derived per sequence: ``None`` for a falsy sequence, otherwise the
    ``strids2ids`` result (no vocabulary) or the ``tokens2ids`` result with
    optional BOS/EOS ids (with a vocabulary).
    """

    def expected_for(sequence):
        # Build the expected reader output for a single input sequence.
        if vocabulary is None:
            return data_io.strids2ids(get_tokens(sequence)) if sequence else None
        ids = data_io.tokens2ids(get_tokens(sequence), vocabulary) if sequence else None
        if add_bos:
            ids = [vocabulary[C.BOS_SYMBOL]] + ids if ids else None
        if add_eos:
            ids = ids + [vocabulary[C.EOS_SYMBOL]] if ids else None
        return ids

    with TemporaryDirectory() as work_dir:
        path = os.path.join(work_dir, 'input')
        with open(path, 'w') as out_file:
            out_file.write("".join("{}\n".format(sequence) for sequence in sequences))

        vocabulary = vocab.build_vocab(sequences) if use_vocab else None

        reader = data_io.SequenceReader(path, vocabulary=vocabulary, add_bos=add_bos, add_eos=add_eos)
        read_sequences = list(reader)
        assert len(read_sequences) == len(sequences)

        if vocabulary is None:
            # Without a vocabulary, constructing a reader with add_bos=True must raise.
            with pytest.raises(SockeyeError) as e:
                data_io.SequenceReader(path, vocabulary=vocabulary, add_bos=True)
            assert str(e.value) == "Adding a BOS or EOS symbol requires a vocabulary"

        expected_sequences = [expected_for(sequence) for sequence in sequences]
        assert read_sequences == expected_sequences
def test_strids2ids(tokens, expected_ids):
    """Check that ``data_io.strids2ids(tokens)`` equals ``expected_ids``."""
    assert data_io.strids2ids(tokens) == expected_ids
def test_strids2ids(tokens, expected_ids):
    """Convert ``tokens`` with ``data_io.strids2ids`` and compare against ``expected_ids``."""
    converted = data_io.strids2ids(tokens)
    assert converted == expected_ids
def test_strids2ids(tokens, expected_ids):
    """Check ``data_io.strids2ids`` against ``expected_ids``; skipped when mxnet is unavailable."""
    # Skip before importing sockeye.data_io — presumably that module needs mxnet (confirm).
    pytest.importorskip('mxnet')
    from sockeye import data_io

    assert data_io.strids2ids(tokens) == expected_ids